[flang-commits] [flang] 22cdeb5 - [MLIR][OpenMP] Add Conversion for Atomic Update Op

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Thu Feb 16 15:58:08 PST 2023


Author: Kiran Chandramohan
Date: 2023-02-16T23:57:35Z
New Revision: 22cdeb54a12363a71bd3168eab6b21b735fd15c5

URL: https://github.com/llvm/llvm-project/commit/22cdeb54a12363a71bd3168eab6b21b735fd15c5
DIFF: https://github.com/llvm/llvm-project/commit/22cdeb54a12363a71bd3168eab6b21b735fd15c5.diff

LOG: [MLIR][OpenMP] Add Conversion for Atomic Update Op

Reviewed By: TIFitis

Differential Revision: https://reviews.llvm.org/D143964

Added: 
    

Modified: 
    flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
    mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index d98d12675a7a7..3764e42939c1c 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -321,3 +321,35 @@ func.func private @_QPwork()
 // CHECK:        }
 // CHECK:        llvm.func @_QPwork() attributes {sym_visibility = "private"}
 // CHECK:      }
+
+// -----
+
+func.func @_QPs() {
+  %0 = fir.address_of(@_QFsEc) : !fir.ref<i32>
+  omp.atomic.update   %0 : !fir.ref<i32> {
+  ^bb0(%arg0: i32):
+    %c1_i32 = arith.constant 1 : i32
+    %1 = arith.addi %arg0, %c1_i32 : i32
+    omp.yield(%1 : i32)
+  }
+  return
+}
+fir.global internal @_QFsEc : i32 {
+  %c10_i32 = arith.constant 10 : i32
+  fir.has_value %c10_i32 : i32
+}
+
+// CHECK-LABEL:  llvm.func @_QPs() {
+// CHECK:    %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @[[GLOBAL:.*]] : !llvm.ptr<i32>
+// CHECK:    omp.atomic.update   %[[GLOBAL_VAR]] : !llvm.ptr<i32> {
+// CHECK:    ^bb0(%[[IN_VAL:.*]]: i32):
+// CHECK:      %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:      %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]]  : i32
+// CHECK:      omp.yield(%[[OUT_VAL]] : i32)
+// CHECK:    }
+// CHECK:    llvm.return
+// CHECK:  }
+// CHECK:  llvm.mlir.global internal @[[GLOBAL]]() {{.*}} : i32 {
+// CHECK:    %[[INIT_10:.*]] = llvm.mlir.constant(10 : i32) : i32
+// CHECK:    llvm.return %[[INIT_10]] : i32
+// CHECK:  }

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d494e89b1274a..7dec1c8126b3c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1359,6 +1359,18 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
     /// Returns the new value if the operation is equivalent to just a write
     /// operation. Otherwise, returns nullptr.
     Value getWriteOpVal();
+
+    /// The number of variable operands.
+    unsigned getNumVariableOperands() {
+      assert(getX() && "expected 'x' operand");
+      return 1;
+    }
+
+    /// The i-th variable operand passed.
+    Value getVariableOperand(unsigned i) {
+      assert(i == 0 && "invalid index position for an operand");
+      return getX();
+    }
   }];
 }
 

diff  --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 822a1abd0b282..621600b268c0f 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -83,6 +83,44 @@ struct RegionLessOpWithVarOperandsConversion
   }
 };
 
+template <typename T>
+struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
+  using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
+    SmallVector<Type> resTypes;
+    if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
+      return failure();
+    SmallVector<Value> convertedOperands;
+    assert(curOp.getNumVariableOperands() ==
+               curOp.getOperation()->getNumOperands() &&
+           "unexpected non-variable operands");
+    for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
+      Value originalVariableOperand = curOp.getVariableOperand(idx);
+      if (!originalVariableOperand)
+        return failure();
+      if (originalVariableOperand.getType().isa<MemRefType>()) {
+        // TODO: Support memref type in variable operands
+        return rewriter.notifyMatchFailure(curOp,
+                                           "memref is not supported yet");
+      }
+      convertedOperands.emplace_back(adaptor.getOperands()[idx]);
+    }
+    auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
+                                    curOp->getAttrs());
+    rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
+                                newOp.getRegion().end());
+    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
+                                           *this->getTypeConverter())))
+      return failure();
+
+    rewriter.eraseOp(curOp);
+    return success();
+  }
+};
+
 struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
   using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern;
   LogicalResult
@@ -114,13 +152,14 @@ struct LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
 void mlir::configureOpenMPToLLVMConversionLegality(
     ConversionTarget &target, LLVMTypeConverter &typeConverter) {
   target.addDynamicallyLegalOp<
-      mlir::omp::CriticalOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp,
-      mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionsOp,
-      mlir::omp::SingleOp, mlir::omp::TaskOp>([&](Operation *op) {
-    return typeConverter.isLegal(&op->getRegion(0)) &&
-           typeConverter.isLegal(op->getOperandTypes()) &&
-           typeConverter.isLegal(op->getResultTypes());
-  });
+      mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::ParallelOp,
+      mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp,
+      mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskOp>(
+      [&](Operation *op) {
+        return typeConverter.isLegal(&op->getRegion(0)) &&
+               typeConverter.isLegal(op->getOperandTypes()) &&
+               typeConverter.isLegal(op->getResultTypes());
+      });
   target.addDynamicallyLegalOp<mlir::omp::AtomicReadOp,
                                mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
                                mlir::omp::ThreadprivateOp, mlir::omp::DataOp,
@@ -145,6 +184,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionOpConversion<omp::TaskOp>,
       RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
       RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
+      RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>,
       RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
       RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
       LegalizeDataOpForLLVMTranslation<omp::DataOp>,

diff  --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 354c67912377b..74c1b19ea5102 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -109,6 +109,32 @@ func.func @atomic_read(%a: !llvm.ptr<i32>, %b: !llvm.ptr<i32>) -> () {
 
 // -----
 
+func.func @atomic_update() {
+  %0 = llvm.mlir.addressof @_QFsEc : !llvm.ptr<i32>
+  omp.atomic.update   %0 : !llvm.ptr<i32> {
+  ^bb0(%arg0: i32):
+    %1 = arith.constant 1 : i32
+    %2 = arith.addi %arg0, %1  : i32
+    omp.yield(%2 : i32)
+  }
+  return
+}
+llvm.mlir.global internal @_QFsEc() : i32 {
+  %0 = arith.constant 10 : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: @atomic_update
+// CHECK: %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @_QFsEc : !llvm.ptr<i32>
+// CHECK: omp.atomic.update   %[[GLOBAL_VAR]] : !llvm.ptr<i32> {
+// CHECK: ^bb0(%[[IN_VAL:.*]]: i32):
+// CHECK:   %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:   %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]]  : i32
+// CHECK:   omp.yield(%[[OUT_VAL]] : i32)
+// CHECK: }
+
+// -----
+
 // CHECK-LABEL: @threadprivate
 // CHECK: (%[[ARG0:.*]]: !llvm.ptr<i32>)
 // CHECK: %[[VAL0:.*]] = omp.threadprivate %[[ARG0]] : !llvm.ptr<i32> -> !llvm.ptr<i32>


        


More information about the flang-commits mailing list