[Mlir-commits] [mlir] 59989d6 - [MLIR][OpenMP] Add support for critical construct
Kiran Chandramohan
llvmlistbot at llvm.org
Tue Aug 3 02:50:48 PDT 2021
Author: Kiran Chandramohan
Date: 2021-08-03T10:50:21+01:00
New Revision: 59989d68ba065b8dc1909d525dfd135d9e3c0206
URL: https://github.com/llvm/llvm-project/commit/59989d68ba065b8dc1909d525dfd135d9e3c0206
DIFF: https://github.com/llvm/llvm-project/commit/59989d68ba065b8dc1909d525dfd135d9e3c0206.diff
LOG: [MLIR][OpenMP] Add support for critical construct
This patch adds the critical construct to the OpenMP dialect. The
implementation models the definition in 2.17.1 of the OpenMP 5 standard.
A name and hint can be specified. The name is a global entity or has
external linkage, so it is modelled as a FlatSymbolRefAttr. The hint is
modelled as an integer enum attribute.
This patch also adds lowering to LLVM IR using the OpenMP IRBuilder.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D107135
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 6cab36ddc570a..596eb1b8ed628 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -349,6 +349,43 @@ def MasterOp : OpenMP_Op<"master"> {
let assemblyFormat = "$region attr-dict";
}
+//===----------------------------------------------------------------------===//
+// 2.17.1 critical Construct
+//===----------------------------------------------------------------------===//
+// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
+// Values mirror omp_sync_hint_t from the OpenMP specification: every hint
+// other than `none` occupies its own bit. The numeric value is forwarded
+// unchanged to the runtime when the op is translated to LLVM IR, so it must
+// match the runtime's encoding (nonspeculative = 0x4, speculative = 0x8).
+def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
+def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
+def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
+def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 4>;
+def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 8>;
+
+// 32-bit integer enum attribute grouping the omp_sync_hint_* cases; it is the
+// type of the optional `hint` argument of the critical op. The generated C++
+// enum is placed in ::mlir::omp, and the string <-> enum conversion helpers
+// are named ConvertToEnum / ConvertToString.
+def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
+ [omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
+ omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
+ let cppNamespace = "::mlir::omp";
+ let stringToSymbolFnName = "ConvertToEnum";
+ let symbolToStringFnName = "ConvertToString";
+}
+
+// Operation modelling the OpenMP `critical` construct (`omp.critical`).
+//   $name   - optional flat symbol reference naming the critical section.
+//   $hint   - optional SyncHintKind value.
+//   $region - the single region holding the guarded code.
+// Custom syntax: an optional `(@name)`, an optional `hint(<kind>)`, then the
+// region and the attribute dictionary.
+def CriticalOp : OpenMP_Op<"critical"> {
+ let summary = "critical construct";
+ let description = [{
+ The critical construct imposes a restriction on the associated structured
+ block (region) to be executed by only a single thread at a time.
+ }];
+
+ let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
+ OptionalAttr<SyncHintKind>:$hint);
+
+ let regions = (region AnyRegion:$region);
+
+ let assemblyFormat = [{
+ (`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
+ }];
+
+ // Verification is delegated to the file-local verifyCriticalOp function.
+ let verifier = "return ::verifyCriticalOp(*this);";
+}
+
//===----------------------------------------------------------------------===//
// 2.17.2 barrier Construct
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b5abdc7426ac5..314fa349e96cf 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -974,5 +974,13 @@ static LogicalResult verifyWsLoopOp(WsLoopOp op) {
return success();
}
+/// Verifies a critical op: when no name is given, the hint must either be
+/// absent or be `none`; otherwise an op error is emitted.
+static LogicalResult verifyCriticalOp(CriticalOp op) {
+  // A named critical section may carry any hint.
+  if (op.name().hasValue())
+    return success();
+
+  // Unnamed: accept only a missing hint or an explicit hint(none).
+  auto hintKind = op.hint();
+  if (!hintKind.hasValue() || hintKind.getValue() == SyncHintKind::none)
+    return success();
+
+  return op.emitOpError() << "must specify a name unless the effect is as if "
+                             "hint(none) is specified";
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6259612d5112b..686386a3542a9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -204,6 +204,45 @@ convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
+/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
+///
+/// The op's single region is translated by `convertOmpOpRegions` from inside
+/// the body-generation callback handed to `OpenMPIRBuilder::createCritical`.
+/// The optional name (empty string when absent) and the hint (0 when absent)
+/// are forwarded to `createCritical`.
+///
+/// Returns failure if translating the region body reported a failure.
+static LogicalResult
+convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  auto criticalOp = cast<omp::CriticalOp>(opInst);
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationBlock) {
+    // CriticalOp has only one region associated with it.
+    auto &region = criticalOp.getRegion();
+    convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(
+      builder.saveIP(), builder.getCurrentDebugLocation());
+  llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+  // An absent hint is lowered as 0, the same value as hint(none).
+  const int hintValue = criticalOp.hint().hasValue()
+                            ? static_cast<int>(criticalOp.hint().getValue())
+                            : 0;
+  llvm::Constant *hint =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), hintValue);
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
+      ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
+  // Propagate a failure recorded by the body callback instead of always
+  // reporting success.
+  return bodyGenStatus;
+}
+
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -365,6 +404,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case([&](omp::MasterOp) {
return convertOmpMaster(*op, builder, moduleTranslation);
})
+ .Case([&](omp::CriticalOp) {
+ return convertOmpCritical(*op, builder, moduleTranslation);
+ })
.Case([&](omp::WsLoopOp) {
return convertOmpWsLoop(*op, builder, moduleTranslation);
})
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 4c85025c65a9d..a755f18b9250a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -293,3 +293,13 @@ func @foo(%lb : index, %ub : index, %step : index, %mem : memref<1xf32>) {
}
return
}
+
+// -----
+
+// An unnamed omp.critical carrying a hint other than `none` must be rejected
+// by the op verifier.
+func @omp_critical() -> () {
+ // expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
+ omp.critical hint(nonspeculative) {
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 35ac6b30593b2..2e4f5335c30f6 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -369,3 +369,14 @@ func @reduction2(%lb : index, %ub : index, %step : index) {
return
}
+// CHECK-LABEL: omp_critical
+// Round-trips omp.critical in its unnamed form and in its named form with a
+// hint, checking that both are printed back.
+func @omp_critical() -> () {
+  // CHECK: omp.critical
+  omp.critical {
+    omp.terminator
+  }
+
+  // CHECK: omp.critical(@mutex) hint(nonspeculative)
+  omp.critical(@mutex) hint(nonspeculative) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index e53ec47370eb8..51cefc1ae2ebb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -466,3 +466,28 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () {
}
llvm.return
}
+
+// CHECK-LABEL: @omp_critical
+// Translates two critical sections to LLVM IR: an unnamed one without a hint
+// (expected hint operand 0) and one named @mutex with hint(contended)
+// (expected hint operand 2). Each is expected to be bracketed by the
+// runtime's critical entry and exit calls.
+llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
+ // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
+ // CHECK: br label %omp.critical.region
+ // CHECK: omp.critical.region
+ omp.critical {
+ // CHECK: store
+ llvm.store %xval, %x : !llvm.ptr<i32>
+ omp.terminator
+ }
+ // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_.var{{.*}})
+
+ // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2)
+ // CHECK: br label %omp.critical.region
+ // CHECK: omp.critical.region
+ omp.critical(@mutex) hint(contended) {
+ // CHECK: store
+ llvm.store %xval, %x : !llvm.ptr<i32>
+ omp.terminator
+ }
+ // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
+
+ llvm.return
+}
More information about the Mlir-commits
mailing list