[flang-commits] [flang] [flang] Fix the incorrect insertion point for alloca (PR #65999)

Hao Jin via flang-commits flang-commits at lists.llvm.org
Mon Sep 11 12:25:32 PDT 2023


https://github.com/erjin created https://github.com/llvm/llvm-project/pull/65999:

When creating a temporary alloca for a box inside an OpenMP region, the insertion point should be the OpenMP region block instead of the function entry block.

>From 12bd6dc23bf26431d5e0757c111dfc51d4170e0a Mon Sep 17 00:00:00 2001
From: Hao Jin <hao.jin at huawei.com>
Date: Mon, 11 Sep 2023 14:56:31 -0400
Subject: [PATCH] [flang] Fix the incorrect insertion point for alloca

When creating a temporary alloca for a box inside an OpenMP region,
the insertion point should be the OpenMP region block instead of the
function entry block.
---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 29 +++++++-----
 .../Fir/convert-to-llvm-openmp-and-fir.fir    | 44 ++++++++++++++++++-
 2 files changed, 60 insertions(+), 13 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index e51050b5a594a61..f8b76f95f56953d 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -333,16 +333,20 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
     return rewriter.create<mlir::LLVM::GEPOp>(loc, ty, base, cv);
   }
 
-  // Find the LLVMFuncOp in whose entry block the alloca should be inserted.
-  // The order to find the LLVMFuncOp is as follows:
-  // 1. The parent operation of the current block if it is a LLVMFuncOp.
-  // 2. The first ancestor that is a LLVMFuncOp.
-  mlir::LLVM::LLVMFuncOp
-  getFuncForAllocaInsert(mlir::ConversionPatternRewriter &rewriter) const {
-    mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
-    return mlir::isa<mlir::LLVM::LLVMFuncOp>(parentOp)
-               ? mlir::cast<mlir::LLVM::LLVMFuncOp>(parentOp)
-               : parentOp->getParentOfType<mlir::LLVM::LLVMFuncOp>();
+  // Find the Block in which the alloca should be inserted.
+  // The order to recursively find the proper block:
+  // 1. An OpenMP Op
+  // 2. A LLVMFuncOp
+  // 3. The first ancestor that is an OpenMP Op or a LLVMFuncOp
+  static mlir::Block *getBlockForAllocaInsert(mlir::Operation *op) {
+    if (auto iface =
+            mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op)) {
+      return iface.getAllocaBlock();
+    } else if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op)) {
+      return &llvmFuncOp.front();
+    } else {
+      return getBlockForAllocaInsert(op->getParentOp());
+    }
   }
 
   // Generate an alloca of size 1 and type \p toTy.
@@ -350,8 +354,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
   genAllocaWithType(mlir::Location loc, mlir::Type toTy, unsigned alignment,
                     mlir::ConversionPatternRewriter &rewriter) const {
     auto thisPt = rewriter.saveInsertionPoint();
-    mlir::LLVM::LLVMFuncOp func = getFuncForAllocaInsert(rewriter);
-    rewriter.setInsertionPointToStart(&func.front());
+    mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
+    mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
+    rewriter.setInsertionPointToStart(insertBlock);
     auto size = genI32Constant(loc, rewriter, 1);
     auto al = rewriter.create<mlir::LLVM::AllocaOp>(loc, toTy, size, alignment);
     rewriter.restoreInsertionPoint(thisPt);
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 06fc1d0edbe2e62..5863cc4fb78db7f 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -654,4 +654,46 @@ func.func @_QPs(%arg0: !fir.ref<!fir.complex<4>> {fir.bindc_name = "x"}) {
   %0 = fir.alloca !fir.complex<4> {bindc_name = "v", uniq_name = "_QFsEv"}
   omp.atomic.read %0 = %arg0   : !fir.ref<!fir.complex<4>>, !fir.complex<4>
   return
-} 
+}
+
+// -----
+// Test that the temporary llvm.alloca is inserted inside the omp.parallel region
+
+//CHECK:  %[[CONST:.*]] = llvm.mlir.constant(1 : i64) : i64
+//CHECK:  %[[ALLOCA:.*]] = llvm.alloca %[[CONST]] x !llvm.struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)> {bindc_name = "iattr", in_type = !fir.box<!fir.ptr<i32>>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFEiattr"} : (i64) -> !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
+//CHECK:  %[[CONST_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//CHECK:  %[[ALLOCA_1:.*]] = llvm.alloca %[[CONST_1]] x i32 {bindc_name = "y", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFEy"} : (i64) -> !llvm.ptr<i32>
+//CHECK:  omp.parallel   {
+//CHECK:    %[[CONST_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+//CHECK:    %[[ALLOCA_2:.*]] = llvm.alloca %[[CONST_2:.*]] x !llvm.struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
+//CHECK:    %[[LOAD:.*]] = llvm.load %[[ALLOCA_1]] : !llvm.ptr<i32>
+//CHECK:    %[[LOAD_1:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
+//CHECK:    llvm.store %[[LOAD_1]], %[[ALLOCA_2]] : !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
+//CHECK:    %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA_2]][0, 0] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<ptr<i32>>
+//CHECK:    %[[LOAD_2:.*]] = llvm.load %[[GEP]] : !llvm.ptr<ptr<i32>>
+//CHECK:    llvm.store %[[LOAD]], %[[LOAD_2]] : !llvm.ptr<i32>
+//CHECK:    omp.terminator
+//CHECK:  }
+
+func.func @_QQmain() attributes {fir.bindc_name = "mn"} {
+  %0 = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "iattr", uniq_name = "_QFEiattr"}
+  %1 = fir.zero_bits !fir.ptr<i32>
+  %2 = fir.embox %1 : (!fir.ptr<i32>) -> !fir.box<!fir.ptr<i32>>
+  fir.store %2 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  %3 = fir.address_of(@_QFEx) : !fir.ref<i32>
+  %4 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
+  %5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
+  fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  omp.parallel   {
+    %6 = fir.load %4 : !fir.ref<i32>
+    %7 = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+    %8 = fir.box_addr %7 : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+    fir.store %6 to %8 : !fir.ptr<i32>
+    omp.terminator
+  }
+  return
+}
+fir.global internal @_QFEx target : i32 {
+  %0 = fir.zero_bits i32
+  fir.has_value %0 : i32
+}



More information about the flang-commits mailing list