[flang-commits] [flang] [flang] Use LLVM dialect ops for stack save/restore in target-rewrite (PR #107879)

via flang-commits flang-commits at lists.llvm.org
Mon Sep 9 08:48:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

Mostly NFC. I was bothered by the declarations that were always made even if unused, and I think using LLVM Ops is nicer anyway with regard to side effects here.

```
func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
func.func private @llvm.stackrestore.p0(!fir.ref<i8>)
```

There are other places in lowering that are using the calls instead of the LLVM intrinsics, but I will deal with them another time (the issue there is mostly to get the proper address space for the llvm.ptr type).

---
Full diff: https://github.com/llvm/llvm-project/pull/107879.diff


5 Files Affected:

- (modified) flang/include/flang/Optimizer/CodeGen/CGPasses.td (+1-1) 
- (modified) flang/lib/Optimizer/CodeGen/TargetRewrite.cpp (+12-11) 
- (modified) flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir (+2-2) 
- (modified) flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir (+2-2) 
- (modified) flang/test/Fir/target-rewrite-complex16.fir (+8-10) 


``````````diff
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index e9e303df09eeba..2e097faec54036 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -68,7 +68,7 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
       representations that may differ based on the target machine.
   }];
   let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
-                            "mlir::DLTIDialect" ];
+                            "mlir::DLTIDialect", "mlir::LLVM::LLVMDialect" ];
   let options = [
     Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
            "Override module's target triple.">,
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 85bf90e4750633..a2a9cff4c4977e 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -27,6 +27,7 @@
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/Support/DataLayout.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -114,13 +115,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
     setMembers(specifics.get(), &rewriter, &*dl);
 
-    // We may need to call stacksave/stackrestore later, so
-    // create the FuncOps beforehand.
-    fir::FirOpBuilder builder(rewriter, mod);
-    builder.setInsertionPointToStart(mod.getBody());
-    stackSaveFn = fir::factory::getLlvmStackSave(builder);
-    stackRestoreFn = fir::factory::getLlvmStackRestore(builder);
-
     // Perform type conversion on signatures and call sites.
     if (mlir::failed(convertTypes(mod))) {
       mlir::emitError(mlir::UnknownLoc::get(&context),
@@ -1242,22 +1236,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
 
+  uint64_t getAllocaAddressSpace() const {
+    if (dataLayout)
+      if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
+        return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
+    return 0;
+  }
+
   // Inserts a call to llvm.stacksave at the current insertion
   // point and the given location. Returns the call's result Value.
   inline mlir::Value genStackSave(mlir::Location loc) {
-    return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
+    mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
+        rewriter->getContext(), getAllocaAddressSpace());
+    return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
   }
 
   // Inserts a call to llvm.stackrestore at the current insertion
   // point and the given location and argument.
   inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
-    rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
+    rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
   }
 
   fir::CodeGenSpecifics *specifics = nullptr;
   mlir::OpBuilder *rewriter = nullptr;
   mlir::DataLayout *dataLayout = nullptr;
-  mlir::func::FuncOp stackSaveFn = nullptr;
-  mlir::func::FuncOp stackRestoreFn = nullptr;
 };
 } // namespace
diff --git a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
index 9d4745becd8523..e37e8dd4481d06 100644
--- a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
@@ -13,13 +13,13 @@ func.func @test_call_i16(%0 : !fir.ref<!fir.type<ti16{i:i16}>>) {
 // CHECK-LABEL:   func.func @test_call_i16(
 // CHECK-SAME:                             %[[VAL_0:.*]]: !fir.ref<!fir.type<ti16{i:i16}>>) {
 // CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti16{i:i16}>>
-// CHECK:           %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_3:.*]] = fir.alloca i16
 // CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<i16>) -> !fir.ref<!fir.type<ti16{i:i16}>>
 // CHECK:           fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti16{i:i16}>>
 // CHECK:           %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<i16>
 // CHECK:           fir.call @test_func_i16(%[[VAL_5]]) : (i16) -> ()
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
 
 func.func private @test_func_i16(%0 : !fir.type<ti16{i:i16}>) -> () {
   return
diff --git a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
index 82139492cea700..9a0a41e1da542a 100644
--- a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
@@ -14,7 +14,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
 // CHECK-LABEL:   func.func @test_call_i8_a16(
 // CHECK-SAME:                                %[[VAL_0:.*]]: !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>) {
 // CHECK:           %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
-// CHECK:           %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_3:.*]] = fir.alloca tuple<i64, i64>
 // CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, i64>>) -> !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
 // CHECK:           fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
@@ -22,7 +22,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
 // CHECK:           %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple<i64, i64>) -> i64
 // CHECK:           %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple<i64, i64>) -> i64
 // CHECK:           fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
 // CHECK:           return
 
 func.func private @test_func_i8_a16(%0 : !fir.type<ti8_a16{a:!fir.array<16xi8>}>) -> () {
diff --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir
index 69ee28ea337bf6..304f15a828454e 100644
--- a/flang/test/Fir/target-rewrite-complex16.fir
+++ b/flang/test/Fir/target-rewrite-complex16.fir
@@ -63,18 +63,18 @@ func.func @addrof() {
 // CHECK:         func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
 
 // CHECK-LABEL:   func.func @callcomplex16() {
-// CHECK:           %[[VAL_0:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           %[[VAL_0:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_1:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
 // CHECK:           fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
 // CHECK:           %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.complex<16>>
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_0]]) : (!fir.ref<i8>) -> ()
-// CHECK:           %[[VAL_4:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           llvm.intr.stackrestore %[[VAL_0]] : !llvm.ptr
+// CHECK:           %[[VAL_4:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_5:.*]] = fir.alloca !fir.complex<16>
 // CHECK:           fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
 // CHECK:           fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_4]]) : (!fir.ref<i8>) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_4]] : !llvm.ptr
 // CHECK:           return
 // CHECK:         }
 // CHECK:         func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
@@ -87,7 +87,7 @@ func.func @addrof() {
 // CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.complex<16>>
-// CHECK:           %[[VAL_9:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           %[[VAL_9:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_10:.*]] = fir.alloca !fir.complex<16>
 // CHECK:           fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
@@ -98,7 +98,7 @@ func.func @addrof() {
 // CHECK:           fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
 // CHECK:           fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_9]]) : (!fir.ref<i8>) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_9]] : !llvm.ptr
 // CHECK:           return
 // CHECK:         }
 
@@ -108,7 +108,7 @@ func.func @addrof() {
 // CHECK:           %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
 // CHECK:           %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
 // CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<complex<f128>>
-// CHECK:           %[[VAL_7:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK:           %[[VAL_7:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = fir.alloca tuple<f128, f128>
 // CHECK:           %[[VAL_9:.*]] = fir.alloca complex<f128>
 // CHECK:           fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref<complex<f128>>
@@ -119,7 +119,7 @@ func.func @addrof() {
 // CHECK:           fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>) -> ()
 // CHECK:           %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
 // CHECK:           %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<complex<f128>>
-// CHECK:           fir.call @llvm.stackrestore.p0(%[[VAL_7]]) : (!fir.ref<i8>) -> ()
+// CHECK:           llvm.intr.stackrestore %[[VAL_7]] : !llvm.ptr
 // CHECK:           %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
 // CHECK:           fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref<complex<f128>>
 // CHECK:           return
@@ -130,5 +130,3 @@ func.func @addrof() {
 // CHECK:           %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
 // CHECK:           return
 // CHECK:         }
-// CHECK:         func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
-// CHECK:         func.func private @llvm.stackrestore.p0(!fir.ref<i8>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/107879


More information about the flang-commits mailing list