[flang-commits] [flang] [flang][FIR][Mem2Reg] Add support for FIR. (PR #172808)
Ming Yan via flang-commits
flang-commits at lists.llvm.org
Sun Dec 21 18:58:43 PST 2025
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/172808
>From 0f73278dc23e2ab8fba666fbebcedbde45289c45 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Thu, 18 Dec 2025 14:55:08 +0800
Subject: [PATCH 1/3] [flang][FIR][Mem2Reg] Add support for FIR. This patch
implements Mem2Reg interfaces for FIR.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 20 +++-
.../include/flang/Optimizer/Support/InitFIR.h | 1 +
flang/lib/Optimizer/Dialect/FIROps.cpp | 94 +++++++++++++++++++
flang/test/Fir/mem2reg.mlir | 68 ++++++++++++++
4 files changed, 178 insertions(+), 5 deletions(-)
create mode 100644 flang/test/Fir/mem2reg.mlir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index cfce9fca504ec..7bfb304b973f5 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -17,6 +17,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -80,7 +81,10 @@ def AnyRefOfConstantSizeAggregateType : TypeConstraint<
// Memory SSA operations
//===----------------------------------------------------------------------===//
-def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments]> {
+def fir_AllocaOp : fir_Op<"alloca", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<PromotableAllocationOpInterface>
+]> {
let summary = "allocate storage for a temporary on the stack given a type";
let description = [{
This primitive operation is used to allocate an object on the stack. A
@@ -288,8 +292,11 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
}
-def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_LoadOp : fir_OneResultOp<"load", [
+ FirAliasTagOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
let summary = "load a value from a memory reference";
let description = [{
Load a value from a memory reference into an ssa-value (virtual register).
@@ -319,8 +326,11 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
}];
}
-def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_StoreOp : fir_Op<"store", [
+ FirAliasTagOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
let summary = "store an SSA-value to a memory location";
let description = [{
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 67e9287ddad4f..41a979f97aece 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -129,6 +129,7 @@ inline void registerMLIRPassesForFortranTools() {
mlir::affine::registerAffineLoopTilingPass();
mlir::affine::registerAffineDataCopyGenerationPass();
+ mlir::registerMem2RegPass();
mlir::registerLowerAffinePass();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4e797d651cb7a..a004f7d9697ac 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -186,6 +186,33 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
return fir::ReferenceType::get(intype);
}
+llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
+ // TODO: support promotion of allocas with LEN params or shape operands
+ if (hasLenParams() || hasShapeOperands())
+ return {};
+
+ return {mlir::MemorySlot{getResult(), getAllocatedType()}};
+} // Otherwise: exactly one slot, the alloca result typed by getAllocatedType().
+
+mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot, // Value seen by loads of the slot that no store reaches:
+ mlir::OpBuilder &builder) {
+ return fir::UndefOp::create(builder, getLoc(), slot.elemType); // a fir.undefined of the slot element type.
+}
+
+void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot,
+ mlir::BlockArgument argument,
+ mlir::OpBuilder &builder) {} // No-op: nothing is attached to block arguments introduced for this slot.
+
+std::optional<mlir::PromotableAllocationOpInterface> // Erases the default value if it ended up unused, then the alloca; never returns a new allocation.
+fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot,
+ mlir::Value defaultValue,
+ mlir::OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty())
+ defaultValue.getDefiningOp()->erase();
+ this->erase();
+ return std::nullopt;
+}
+
mlir::Type fir::AllocaOp::getAllocatedType() {
return mlir::cast<fir::ReferenceType>(getType()).getEleTy();
}
@@ -2861,6 +2888,40 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() {
// LoadOp
//===----------------------------------------------------------------------===//
+bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) { // Reads the slot iff the memref is the slot pointer itself.
+ return getMemref() == slot.ptr;
+}
+
+bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; } // A load never writes the slot.
+
+mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot, // Not expected to be called: storesTo() always returns false for loads.
+ mlir::OpBuilder &builder,
+ mlir::Value reachingDef,
+ const mlir::DataLayout &dataLayout) {
+ llvm_unreachable("getStored should not be called on LoadOp");
+}
+
+bool fir::LoadOp::canUsesBeRemoved(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses, // never extended here
+ const mlir::DataLayout &dataLayout) { // Removable only with exactly one blocking use that, like this load's memref, is the slot pointer.
+ if (blockingUses.size() != 1)
+ return false;
+ mlir::Value blockingUse = (*blockingUses.begin())->get();
+ return blockingUse == slot.ptr && getMemref() == slot.ptr &&
+ getType() == slot.elemType;
+}
+
+mlir::DeletionKind fir::LoadOp::removeBlockingUses(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+ const mlir::DataLayout &dataLayout) {
+ getResult().replaceAllUsesWith(reachingDefinition); // Users of the loaded value now read the reaching definition,
+ return mlir::DeletionKind::Delete; // so the load itself can be deleted.
+}
+
void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::Value refVal) {
if (!refVal) {
@@ -4256,6 +4317,39 @@ llvm::LogicalResult fir::SliceOp::verify() {
// StoreOp
//===----------------------------------------------------------------------===//
+bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; } // A store never reads the slot.
+
+bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) { // Writes the slot iff the memref is the slot pointer itself.
+ return getMemref() == slot.ptr;
+}
+
+mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot, // Returns the value this store writes;
+ mlir::OpBuilder &builder,
+ mlir::Value reachingDef, // unused: the written value does not depend on earlier definitions.
+ const mlir::DataLayout &dataLayout) {
+ return getValue();
+}
+
+bool fir::StoreOp::canUsesBeRemoved(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses, // never extended here
+ const mlir::DataLayout &dataLayout) { // Removable only if the sole blocking use is the slot pointer used as the destination, never as the stored value.
+ if (blockingUses.size() != 1)
+ return false;
+ mlir::Value blockingUse = (*blockingUses.begin())->get();
+ return blockingUse == slot.ptr && getMemref() == slot.ptr &&
+ getValue() != slot.ptr && slot.elemType == getValue().getType();
+}
+
+mlir::DeletionKind fir::StoreOp::removeBlockingUses(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+ const mlir::DataLayout &dataLayout) {
+ return mlir::DeletionKind::Delete; // Nothing to rewrite: the store is simply dropped.
+}
+
mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
return fir::dyn_cast_ptrEleTy(refType);
}
diff --git a/flang/test/Fir/mem2reg.mlir b/flang/test/Fir/mem2reg.mlir
new file mode 100644
index 0000000000000..25d114a55e1a4
--- /dev/null
+++ b/flang/test/Fir/mem2reg.mlir
@@ -0,0 +1,68 @@
+// RUN: fir-opt %s --allow-unregistered-dialect --mem2reg --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic() -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK: return %[[CONSTANT_0]] : i32
+// CHECK: }
+func.func @basic() -> i32 { // Store then load in one block: the load is replaced by the stored constant.
+ %0 = arith.constant 5 : i32
+ %1 = fir.alloca i32
+ fir.store %0 to %1 : !fir.ref<i32>
+ %2 = fir.load %1 : !fir.ref<i32>
+ return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @default_value() -> i32 {
+// CHECK: %[[UNDEFINED_0:.*]] = fir.undefined i32
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK: return %[[UNDEFINED_0]] : i32
+// CHECK: }
+func.func @default_value() -> i32 { // The load precedes the only store, so it reads the slot's default value.
+ %0 = arith.constant 5 : i32
+ %1 = fir.alloca i32
+ %2 = fir.load %1 : !fir.ref<i32>
+ fir.store %0 to %1 : !fir.ref<i32>
+ return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_float() -> f32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5.200000e+00 : f32
+// CHECK: return %[[CONSTANT_0]] : f32
+// CHECK: }
+func.func @basic_float() -> f32 { // Same pattern as @basic, with an f32 slot.
+ %0 = arith.constant 5.2 : f32
+ %1 = fir.alloca f32
+ fir.store %0 to %1 : !fir.ref<f32>
+ %2 = fir.load %1 : !fir.ref<f32>
+ return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cycle(
+// CHECK-SAME: %[[ARG0:.*]]: i64,
+// CHECK-SAME: %[[ARG1:.*]]: i1,
+// CHECK-SAME: %[[ARG2:.*]]: i64) {
+// CHECK: cf.cond_br %[[ARG1]], ^bb1(%[[ARG2]] : i64), ^bb2(%[[ARG2]] : i64)
+// CHECK: ^bb1(%[[VAL_0:.*]]: i64):
+// CHECK: "test.use"(%[[VAL_0]]) : (i64) -> ()
+// CHECK: cf.br ^bb2(%[[ARG0]] : i64)
+// CHECK: ^bb2(%[[VAL_1:.*]]: i64):
+// CHECK: cf.br ^bb1(%[[VAL_1]] : i64)
+// CHECK: }
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) { // Slot live around a loop: its value is carried through block arguments.
+ %alloca = fir.alloca i64
+ fir.store %arg2 to %alloca : !fir.ref<i64>
+ cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ %use = fir.load %alloca : !fir.ref<i64>
+ "test.use"(%use) : (i64) -> ()
+ fir.store %arg0 to %alloca : !fir.ref<i64>
+ cf.br ^bb2
+^bb2:
+ cf.br ^bb1
+}
>From b136aaafd3fb7ff9fcbf6dc8cf9fa2a45d536837 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Fri, 19 Dec 2025 11:35:44 +0800
Subject: [PATCH 2/3] Assert undef default value; return null from LoadOp::getStored
---
flang/lib/Optimizer/Dialect/FIROps.cpp | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index a004f7d9697ac..3f1e8ad3dfed1 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -207,8 +207,11 @@ std::optional<mlir::PromotableAllocationOpInterface>
fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot,
mlir::Value defaultValue,
mlir::OpBuilder &builder) {
- if (defaultValue && defaultValue.use_empty())
+ if (defaultValue && defaultValue.use_empty()) {
+ assert(mlir::isa<fir::UndefOp>(defaultValue.getDefiningOp()) &&
+ "Expected undef op to be the default value");
defaultValue.getDefiningOp()->erase();
+ }
this->erase();
return std::nullopt;
}
@@ -2898,7 +2901,7 @@ mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot,
mlir::OpBuilder &builder,
mlir::Value reachingDef,
const mlir::DataLayout &dataLayout) {
- llvm_unreachable("getStored should not be called on LoadOp");
+ return mlir::Value();
}
bool fir::LoadOp::canUsesBeRemoved(
>From 7be2ec7ee434d7e0c0a9fa85944a0e3a327d6c69 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Fri, 19 Dec 2025 18:24:02 +0800
Subject: [PATCH 3/3] Use isDynamic() and drop slot type checks in Mem2Reg hooks.
---
flang/lib/Optimizer/Dialect/FIROps.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 3f1e8ad3dfed1..c2a3d52fe88d2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -187,8 +187,8 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
}
llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
- // TODO: support promotion of allocas with LEN params or shape operands
- if (hasLenParams() || hasShapeOperands())
+ // TODO: support promotion of dynamic allocas
+ if (isDynamic())
return {};
return {mlir::MemorySlot{getResult(), getAllocatedType()}};
@@ -2912,8 +2912,7 @@ bool fir::LoadOp::canUsesBeRemoved(
if (blockingUses.size() != 1)
return false;
mlir::Value blockingUse = (*blockingUses.begin())->get();
- return blockingUse == slot.ptr && getMemref() == slot.ptr &&
- getType() == slot.elemType;
+ return blockingUse == slot.ptr && getMemref() == slot.ptr;
}
mlir::DeletionKind fir::LoadOp::removeBlockingUses(
@@ -4342,7 +4341,7 @@ bool fir::StoreOp::canUsesBeRemoved(
return false;
mlir::Value blockingUse = (*blockingUses.begin())->get();
return blockingUse == slot.ptr && getMemref() == slot.ptr &&
- getValue() != slot.ptr && slot.elemType == getValue().getType();
+ getValue() != slot.ptr;
}
mlir::DeletionKind fir::StoreOp::removeBlockingUses(
More information about the flang-commits
mailing list