[Mlir-commits] [mlir] c4c7e06 - [MLIR][OpenMP] Shifted hint from CriticalOp to CriticalDeclareOp

Shraiysh Vaishay llvmlistbot at llvm.org
Wed Oct 20 09:06:19 PDT 2021


Author: Shraiysh Vaishay
Date: 2021-10-20T21:36:09+05:30
New Revision: c4c7e06bd700aeccfbe5c1f075bd5897f54b68f2

URL: https://github.com/llvm/llvm-project/commit/c4c7e06bd700aeccfbe5c1f075bd5897f54b68f2
DIFF: https://github.com/llvm/llvm-project/commit/c4c7e06bd700aeccfbe5c1f075bd5897f54b68f2.diff

LOG: [MLIR][OpenMP] Shifted hint from CriticalOp to CriticalDeclareOp

According to the OpenMP 5.0 standard, the name and hint of a critical construct
are closely related. The standard places the following restrictions on them:
 - Unless the effect is as if `hint(omp_sync_hint_none)` was specified, the
   critical construct must specify a name.
 - If the hint clause is specified, each of the critical constructs with the
   same name must have a hint clause for which the hint-expression evaluates to
   the same value.

These restrictions will be enforced by design if the hint expression is a part
of the `omp.critical.declare` operation.
 - Any `omp.critical` operation with no "name" will be considered to have
   `hint(omp_sync_hint_none)`.
 - All the operations with the same "name" will have the same hint value.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D112134

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 37dcaf04c0089..e15bbda9bbe98 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -362,9 +362,14 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
     The name can be used in critical constructs in the dialect.
   }];
 
-  let arguments = (ins SymbolNameAttr:$sym_name);
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       DefaultValuedAttr<I64Attr, "0">:$hint);
+
+  let assemblyFormat = [{
+    $sym_name custom<SynchronizationHint>($hint) attr-dict
+  }];
 
-  let assemblyFormat = "$sym_name attr-dict";
+  let verifier = "return verifyCriticalDeclareOp(*this);";
 }
 
 
@@ -375,13 +380,12 @@ def CriticalOp : OpenMP_Op<"critical"> {
     block (region) to be executed by only a single thread at a time.
   }];
 
-  let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
-                       DefaultValuedAttr<I64Attr, "0">:$hint);
+  let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name);
 
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
-    (`(` $name^ `)`)? custom<SynchronizationHint>($hint) $region attr-dict
+    (`(` $name^ `)`)? $region attr-dict
   }];
 
   let verifier = "return ::verifyCriticalOp(*this);";

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5a73850042c06..2478713211dc7 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1110,14 +1110,11 @@ static LogicalResult verifyWsLoopOp(WsLoopOp op) {
 // Verifier for critical construct (2.17.1)
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyCriticalOp(CriticalOp op) {
+static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
+  return verifySynchronizationHint(op, op.hint());
+}
 
-  if (failed(verifySynchronizationHint(op, op.hint()))) {
-    return failure();
-  }
-  if (!op.name().hasValue() && (op.hint() != 0))
-    return op.emitOpError() << "must specify a name unless the effect is as if "
-                               "no hint is specified";
+static LogicalResult verifyCriticalOp(CriticalOp op) {
 
   if (op.nameAttr()) {
     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f91c02cf9956b..dbee4452281a4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -300,8 +300,19 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
       builder.saveIP(), builder.getCurrentDebugLocation());
   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-  llvm::Constant *hint = llvm::ConstantInt::get(
-      llvm::Type::getInt32Ty(llvmContext), static_cast<int>(criticalOp.hint()));
+  llvm::Constant *hint = nullptr;
+
+  // If it has a name, it probably has a hint too.
+  if (criticalOp.nameAttr()) {
+    // The verifiers in OpenMP Dialect guarantee that all the pointers are
+    // non-null
+    auto symbolRef = criticalOp.nameAttr().cast<SymbolRefAttr>();
+    auto criticalDeclareOp =
+        SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
+                                                                     symbolRef);
+    hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
+                                  static_cast<int>(criticalDeclareOp.hint()));
+  }
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
   return success();

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 36d40ad455d06..f4637f72ecdd1 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -296,19 +296,9 @@ func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
 
 // -----
 
-func @omp_critical1() -> () {
-  // expected-error @below {{must specify a name unless the effect is as if no hint is specified}}
-  omp.critical hint(nonspeculative) {
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
 func @omp_critical2() -> () {
   // expected-error @below {{expected symbol reference @excl to point to a critical declaration}}
-  omp.critical(@excl) hint(speculative) {
+  omp.critical(@excl) {
     omp.terminator
   }
   return
@@ -316,32 +306,15 @@ func @omp_critical2() -> () {
 
 // -----
 
-omp.critical.declare @mutex
-func @omp_critical() -> () {
-  // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
-  omp.critical(@mutex) hint(uncontended, contended) {
-    omp.terminator
-  }
-  return
-}
+// expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+omp.critical.declare @mutex hint(uncontended, contended)
 
 // -----
 
-omp.critical.declare @mutex
-func @omp_critical() -> () {
-  // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
-  omp.critical(@mutex) hint(nonspeculative, speculative) {
-    omp.terminator
-  }
-  return
-}
+// expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
+omp.critical.declare @mutex hint(nonspeculative, speculative)
 
 // -----
 
-omp.critical.declare @mutex
-func @omp_critica() -> () {
-  // expected-error @below {{invalid_hint is not a valid hint}}
-  omp.critical(@mutex) hint(invalid_hint) {
-    omp.terminator
-  }
-}
+// expected-error @below {{invalid_hint is not a valid hint}}
+omp.critical.declare @mutex hint(invalid_hint)

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 38ce6c2459395..aa6e07719336a 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -369,9 +369,23 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
   return
 }
 
-// CHECK: omp.critical.declare
-// CHECK-LABEL: @mutex
-omp.critical.declare @mutex
+// CHECK: omp.critical.declare @mutex1 hint(uncontended)
+omp.critical.declare @mutex1 hint(uncontended)
+// CHECK: omp.critical.declare @mutex2 hint(contended)
+omp.critical.declare @mutex2 hint(contended)
+// CHECK: omp.critical.declare @mutex3 hint(nonspeculative)
+omp.critical.declare @mutex3 hint(nonspeculative)
+// CHECK: omp.critical.declare @mutex4 hint(speculative)
+omp.critical.declare @mutex4 hint(speculative)
+// CHECK: omp.critical.declare @mutex5 hint(uncontended, nonspeculative)
+omp.critical.declare @mutex5 hint(uncontended, nonspeculative)
+// CHECK: omp.critical.declare @mutex6 hint(contended, nonspeculative)
+omp.critical.declare @mutex6 hint(contended, nonspeculative)
+// CHECK: omp.critical.declare @mutex7 hint(uncontended, speculative)
+omp.critical.declare @mutex7 hint(uncontended, speculative)
+// CHECK: omp.critical.declare @mutex8 hint(contended, speculative)
+omp.critical.declare @mutex8 hint(contended, speculative)
+
 
 // CHECK-LABEL: omp_critical
 func @omp_critical() -> () {
@@ -380,36 +394,8 @@ func @omp_critical() -> () {
     omp.terminator
   }
 
-  // CHECK: omp.critical(@{{.*}}) hint(uncontended)
-  omp.critical(@mutex) hint(uncontended) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(contended)
-  omp.critical(@mutex) hint(contended) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(nonspeculative)
-  omp.critical(@mutex) hint(nonspeculative) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(uncontended, nonspeculative)
-  omp.critical(@mutex) hint(uncontended, nonspeculative) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(contended, nonspeculative)
-  omp.critical(@mutex) hint(nonspeculative, contended) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(speculative)
-  omp.critical(@mutex) hint(speculative) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(uncontended, speculative)
-  omp.critical(@mutex) hint(uncontended, speculative) {
-    omp.terminator
-  }
-  // CHECK: omp.critical(@{{.*}}) hint(contended, speculative)
-  omp.critical(@mutex) hint(speculative, contended) {
+  // CHECK: omp.critical(@{{.*}})
+  omp.critical(@mutex1) {
     omp.terminator
   }
   return

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 6027749ba2d3a..7dfeba573af45 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -469,11 +469,11 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () {
 
 // -----
 
-omp.critical.declare @mutex
+omp.critical.declare @mutex hint(contended)
 
 // CHECK-LABEL: @omp_critical
 llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
-  // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
+  // CHECK: call void @__kmpc_critical({{.*}}critical_user_.var{{.*}})
   // CHECK: br label %omp.critical.region
   // CHECK: omp.critical.region
   omp.critical {
@@ -486,7 +486,7 @@ llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
   // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2)
   // CHECK: br label %omp.critical.region
   // CHECK: omp.critical.region
-  omp.critical(@mutex) hint(contended) {
+  omp.critical(@mutex) {
   // CHECK: store
     llvm.store %xval, %x : !llvm.ptr<i32>
     omp.terminator


        


More information about the Mlir-commits mailing list