[Mlir-commits] [mlir] [OpenMP]Support for lowering masked op (PR #98401)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 10 15:13:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-openmp

Author: Anchu Rajendran S (anchuraj)

<details>
<summary>Changes</summary>

This change adds lowering support for the OpenMP `masked` directive. The `masked` construct was introduced in the OpenMP 5.1 standard and allows a region inside a parallel region to be executed only by the threads the programmer specifies. This is achieved with the `filter` clause, which specifies the ID of the thread that is expected to execute the region.

Other previous related PRs: 
- [Fortran Parsing and Semantic Support](https://github.com/llvm/llvm-project/pull/91432)
- [MLIR Support](https://github.com/llvm/llvm-project/pull/96022/files)


---
Full diff: https://github.com/llvm/llvm-project/pull/98401.diff


2 Files Affected:

- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+40) 
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+29) 


``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0c9c699a1f390..591900723e2e9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -264,6 +264,43 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
   llvm_unreachable("Unknown ClauseProcBindKind kind");
 }
 
+/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
+                 LLVM::ModuleTranslation &moduleTranslation) {
+  auto maskedOp = cast<omp::MaskedOp>(opInst);
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
+    // MaskedOp has only one region associated with it.
+    auto &region = maskedOp.getRegion();
+    builder.restoreIP(codeGenIP);
+    convertOmpOpRegions(region, "omp.masked.region", builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  llvm::Value *filterVal = nullptr;
+  if (auto filterVar = maskedOp.getFilteredThreadId()) {
+    filterVal = moduleTranslation.lookupValue(filterVar);
+  } else {
+    llvm::LLVMContext &llvmContext = builder.getContext();
+    filterVal =
+        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
+  }
+  assert(filterVal != nullptr);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMasked(
+      ompLoc, bodyGenCB, finiCB, filterVal));
+  return success();
+}
+
 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -3330,6 +3367,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
       .Case([&](omp::ParallelOp op) {
         return convertOmpParallel(op, builder, moduleTranslation);
       })
+      .Case([&](omp::MaskedOp) {
+        return convertOmpMasked(*op, builder, moduleTranslation);
+      })
       .Case([&](omp::MasterOp) {
         return convertOmpMaster(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index dfeaf4be33adb..188e570d01465 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -310,6 +310,35 @@ llvm.func @test_omp_master() -> () {
 
 // -----
 
+// CHECK-LABEL: define void @test_omp_masked({{.*}})
+llvm.func @test_omp_masked(%arg0: i32)-> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @{{.*}})
+// CHECK: omp.par.region1:
+  omp.parallel {
+    omp.masked filter(%arg0: i32) {
+// CHECK: [[OMP_THREAD_3_4:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}})
+// CHECK: {{[0-9]+}} = call i32 @__kmpc_masked(ptr @{{[0-9]+}}, i32 [[OMP_THREAD_3_4]], i32 %{{[0-9]+}})
+// CHECK: omp.masked.region
+// CHECK: call void @__kmpc_end_masked(ptr @{{[0-9]+}}, i32 [[OMP_THREAD_3_4]])
+// CHECK: br label %omp_region.end
+      omp.terminator
+    }
+    omp.terminator
+  }
+  omp.parallel {
+    omp.parallel {
+      omp.masked filter(%arg0: i32){
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 // CHECK: %struct.ident_t = type
 // CHECK: @[[$loc:.*]] = private unnamed_addr constant {{.*}} c";unknown;unknown;{{[0-9]+}};{{[0-9]+}};;\00"
 // CHECK: @[[$loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$loc]] {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/98401


More information about the Mlir-commits mailing list