[flang-commits] [flang] [flang] Use LLVM dialect ops for stack save/restore in target-rewrite (PR #107879)
via flang-commits
flang-commits at lists.llvm.org
Mon Sep 9 08:48:23 PDT 2024
https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/107879
Mostly NFC. I was bothered by the declarations that were always made even if unused, and I think using LLVM Ops is nicer anyway with regard to side effects here.
```
func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
func.func private @llvm.stackrestore.p0(!fir.ref<i8>)
```
There are other places in lowering that are using the calls instead of the LLVM intrinsics, but I will deal with them another time (the issue there is mostly to get the proper address space for the llvm.ptr type).
>From 10c3f3a9f96c3d1a1a00e84e15114e10172a0ac5 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 9 Sep 2024 08:41:57 -0700
Subject: [PATCH] [flang] Use LLVM dialect ops for stack save/restore in
target-rewrite
---
.../flang/Optimizer/CodeGen/CGPasses.td | 2 +-
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 23 ++++++++++---------
.../struct-passing-x86-64-one-field-inreg.fir | 4 ++--
...ct-passing-x86-64-several-fields-inreg.fir | 4 ++--
flang/test/Fir/target-rewrite-complex16.fir | 18 +++++++--------
5 files changed, 25 insertions(+), 26 deletions(-)
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index e9e303df09eeba..2e097faec54036 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -68,7 +68,7 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
representations that may differ based on the target machine.
}];
let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
- "mlir::DLTIDialect" ];
+ "mlir::DLTIDialect", "mlir::LLVM::LLVMDialect" ];
let options = [
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
"Override module's target triple.">,
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 85bf90e4750633..a2a9cff4c4977e 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -27,6 +27,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -114,13 +115,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
setMembers(specifics.get(), &rewriter, &*dl);
- // We may need to call stacksave/stackrestore later, so
- // create the FuncOps beforehand.
- fir::FirOpBuilder builder(rewriter, mod);
- builder.setInsertionPointToStart(mod.getBody());
- stackSaveFn = fir::factory::getLlvmStackSave(builder);
- stackRestoreFn = fir::factory::getLlvmStackRestore(builder);
-
// Perform type conversion on signatures and call sites.
if (mlir::failed(convertTypes(mod))) {
mlir::emitError(mlir::UnknownLoc::get(&context),
@@ -1242,22 +1236,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
+ uint64_t getAllocaAddressSpace() const {
+ if (dataLayout)
+ if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
+ return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
+ return 0;
+ }
+
// Inserts a call to llvm.stacksave at the current insertion
// point and the given location. Returns the call's result Value.
inline mlir::Value genStackSave(mlir::Location loc) {
- return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
+ mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
+ rewriter->getContext(), getAllocaAddressSpace());
+ return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
}
// Inserts a call to llvm.stackrestore at the current insertion
// point and the given location and argument.
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
- rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
+ rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
}
fir::CodeGenSpecifics *specifics = nullptr;
mlir::OpBuilder *rewriter = nullptr;
mlir::DataLayout *dataLayout = nullptr;
- mlir::func::FuncOp stackSaveFn = nullptr;
- mlir::func::FuncOp stackRestoreFn = nullptr;
};
} // namespace
diff --git a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
index 9d4745becd8523..e37e8dd4481d06 100644
--- a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
@@ -13,13 +13,13 @@ func.func @test_call_i16(%0 : !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK-LABEL: func.func @test_call_i16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti16{i:i16}>>
-// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca i16
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<i16>) -> !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<i16>
// CHECK: fir.call @test_func_i16(%[[VAL_5]]) : (i16) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
func.func private @test_func_i16(%0 : !fir.type<ti16{i:i16}>) -> () {
return
diff --git a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
index 82139492cea700..9a0a41e1da542a 100644
--- a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
@@ -14,7 +14,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
// CHECK-LABEL: func.func @test_call_i8_a16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
-// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca tuple<i64, i64>
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, i64>>) -> !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
@@ -22,7 +22,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
// CHECK: %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple<i64, i64>) -> i64
// CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple<i64, i64>) -> i64
// CHECK: fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
// CHECK: return
func.func private @test_func_i8_a16(%0 : !fir.type<ti8_a16{a:!fir.array<16xi8>}>) -> () {
diff --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir
index 69ee28ea337bf6..304f15a828454e 100644
--- a/flang/test/Fir/target-rewrite-complex16.fir
+++ b/flang/test/Fir/target-rewrite-complex16.fir
@@ -63,18 +63,18 @@ func.func @addrof() {
// CHECK: func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
// CHECK-LABEL: func.func @callcomplex16() {
-// CHECK: %[[VAL_0:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_0:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_1:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
// CHECK: fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.complex<16>>
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_0]]) : (!fir.ref<i8>) -> ()
-// CHECK: %[[VAL_4:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: llvm.intr.stackrestore %[[VAL_0]] : !llvm.ptr
+// CHECK: %[[VAL_4:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_4]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_4]] : !llvm.ptr
// CHECK: return
// CHECK: }
// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
@@ -87,7 +87,7 @@ func.func @addrof() {
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.complex<16>>
-// CHECK: %[[VAL_9:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_9:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_10:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
@@ -98,7 +98,7 @@ func.func @addrof() {
// CHECK: fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_9]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_9]] : !llvm.ptr
// CHECK: return
// CHECK: }
@@ -108,7 +108,7 @@ func.func @addrof() {
// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<complex<f128>>
-// CHECK: %[[VAL_7:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_7:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = fir.alloca tuple<f128, f128>
// CHECK: %[[VAL_9:.*]] = fir.alloca complex<f128>
// CHECK: fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref<complex<f128>>
@@ -119,7 +119,7 @@ func.func @addrof() {
// CHECK: fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>) -> ()
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<complex<f128>>
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_7]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_7]] : !llvm.ptr
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref<complex<f128>>
// CHECK: return
@@ -130,5 +130,3 @@ func.func @addrof() {
// CHECK: %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: return
// CHECK: }
-// CHECK: func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
-// CHECK: func.func private @llvm.stackrestore.p0(!fir.ref<i8>)
More information about the flang-commits
mailing list