[llvm-branch-commits] [mlir] f7d033f - [mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Nov 23 14:32:40 PST 2020


Author: Alex Zinenko
Date: 2020-11-23T23:28:02+01:00
New Revision: f7d033f4d80f476246a70f165e7455639818f907

URL: https://github.com/llvm/llvm-project/commit/f7d033f4d80f476246a70f165e7455639818f907
DIFF: https://github.com/llvm/llvm-project/commit/f7d033f4d80f476246a70f165e7455639818f907.diff

LOG: [mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion

It is a simple conversion that only requires changing the region argument
types, so generalize the existing ParallelOp conversion pattern to cover it.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D91989

Added: 
    

Modified: 
    mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
    mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index cfb553da407c..91e97ca1ec50 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -16,18 +16,23 @@
 using namespace mlir;
 
 namespace {
-struct ParallelOpConversion : public ConvertToLLVMPattern {
-  explicit ParallelOpConversion(MLIRContext *context,
-                                LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
+/// A pattern that converts the region arguments in a single-region OpenMP
+/// operation to the LLVM dialect. The body of the region is not modified and is
+/// expected to either be processed by the conversion infrastructure or already
+/// contain ops compatible with LLVM dialect types.
+template <typename OpType>
+struct RegionOpConversion : public ConvertToLLVMPattern {
+  explicit RegionOpConversion(MLIRContext *context,
+                              LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(OpType::getOperationName(), context,
                              typeConverter) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto curOp = cast<omp::ParallelOp>(op);
-    auto newOp = rewriter.create<omp::ParallelOp>(curOp.getLoc(), TypeRange(),
-                                                  operands, curOp.getAttrs());
+    auto curOp = cast<OpType>(op);
+    auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
+                                         curOp.getAttrs());
     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
                                 newOp.region().end());
     if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
@@ -42,7 +47,8 @@ struct ParallelOpConversion : public ConvertToLLVMPattern {
 void mlir::populateOpenMPToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &converter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<ParallelOpConversion>(context, converter);
+  patterns.insert<RegionOpConversion<omp::ParallelOp>,
+                  RegionOpConversion<omp::WsLoopOp>>(context, converter);
 }
 
 namespace {
@@ -63,8 +69,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
   populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  target.addDynamicallyLegalOp<omp::ParallelOp>(
-      [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
+  target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
+      [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
   target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
                     omp::BarrierOp, omp::TaskwaitOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))

diff  --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index d38a6ea7e3a9..62ea39f078b2 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -28,3 +28,22 @@ func @branch_loop() {
   }
   return
 }
+
+// CHECK-LABEL: @wsloop
+// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64)
+func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
+  // CHECK: omp.parallel
+  omp.parallel {
+    // CHECK: omp.wsloop
+    // CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]])
+    "omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( {
+    // CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64):
+    ^bb0(%arg6: index, %arg7: index):  // no predecessors
+      // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> ()
+      "test.payload"(%arg6, %arg7) : (index, index) -> ()
+      omp.yield
+    }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> ()
+    omp.terminator
+  }
+  return
+}


        


More information about the llvm-branch-commits mailing list