[flang-commits] [flang] [flang][FIR][Mem2Reg] Add support for FIR. (PR #172808)
Ming Yan via flang-commits
flang-commits at lists.llvm.org
Wed Dec 17 23:30:22 PST 2025
https://github.com/NexMing created https://github.com/llvm/llvm-project/pull/172808
This patch implements Mem2Reg interfaces for FIR.
>From 0f73278dc23e2ab8fba666fbebcedbde45289c45 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Thu, 18 Dec 2025 14:55:08 +0800
Subject: [PATCH] [flang][FIR][Mem2Reg] Add support for FIR. This patch
implements Mem2Reg interfaces for FIR.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 20 +++-
.../include/flang/Optimizer/Support/InitFIR.h | 1 +
flang/lib/Optimizer/Dialect/FIROps.cpp | 94 +++++++++++++++++++
flang/test/Fir/mem2reg.mlir | 68 ++++++++++++++
4 files changed, 178 insertions(+), 5 deletions(-)
create mode 100644 flang/test/Fir/mem2reg.mlir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index cfce9fca504ec..7bfb304b973f5 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -17,6 +17,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -80,7 +81,10 @@ def AnyRefOfConstantSizeAggregateType : TypeConstraint<
// Memory SSA operations
//===----------------------------------------------------------------------===//
-def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments]> {
+def fir_AllocaOp : fir_Op<"alloca", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<PromotableAllocationOpInterface>
+]> {
let summary = "allocate storage for a temporary on the stack given a type";
let description = [{
This primitive operation is used to allocate an object on the stack. A
@@ -288,8 +292,11 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
}
-def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_LoadOp : fir_OneResultOp<"load", [
+ FirAliasTagOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
let summary = "load a value from a memory reference";
let description = [{
Load a value from a memory reference into an ssa-value (virtual register).
@@ -319,8 +326,11 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
}];
}
-def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_StoreOp : fir_Op<"store", [
+ FirAliasTagOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
let summary = "store an SSA-value to a memory location";
let description = [{
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 67e9287ddad4f..41a979f97aece 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -129,6 +129,7 @@ inline void registerMLIRPassesForFortranTools() {
mlir::affine::registerAffineLoopTilingPass();
mlir::affine::registerAffineDataCopyGenerationPass();
+ mlir::registerMem2RegPass();
mlir::registerLowerAffinePass();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4e797d651cb7a..a004f7d9697ac 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -186,6 +186,33 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
return fir::ReferenceType::get(intype);
}
+/// Expose this alloca to Mem2Reg as a single slot whose pointer is the op's
+/// result and whose element type is the allocated type. Allocations that
+/// carry LEN parameters or shape operands are not offered for promotion.
+llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
+  // TODO: support promotion of allocas with LEN params or shape operands.
+  const bool isDynamicallySized = hasLenParams() || hasShapeOperands();
+  if (isDynamicallySized)
+    return {};
+  mlir::MemorySlot slot{getResult(), getAllocatedType()};
+  return {slot};
+}
+
+/// Value used for reads of the slot that no store reaches: a fresh
+/// `fir.undefined` of the slot's element type, built at the builder's
+/// current insertion point.
+mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot,
+                                           mlir::OpBuilder &builder) {
+  mlir::Location loc = getLoc();
+  return fir::UndefOp::create(builder, loc, slot.elemType);
+}
+
+/// Hook invoked when Mem2Reg adds a block argument for this slot. FIR needs
+/// no extra bookkeeping here, so the hook does nothing.
+void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot & /*slot*/,
+                                        mlir::BlockArgument /*argument*/,
+                                        mlir::OpBuilder & /*builder*/) {
+  // Intentionally empty.
+}
+
+/// Final cleanup once the slot has been promoted. If a default value was
+/// materialized but ended up unused, its defining op is erased first; then
+/// the alloca itself is erased. No replacement allocation op is returned.
+std::optional<mlir::PromotableAllocationOpInterface>
+fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot & /*slot*/,
+                                       mlir::Value defaultValue,
+                                       mlir::OpBuilder & /*builder*/) {
+  const bool defaultIsDead = defaultValue && defaultValue.use_empty();
+  if (defaultIsDead) {
+    mlir::Operation *defaultOp = defaultValue.getDefiningOp();
+    defaultOp->erase();
+  }
+  erase();
+  return std::nullopt;
+}
+
mlir::Type fir::AllocaOp::getAllocatedType() {
return mlir::cast<fir::ReferenceType>(getType()).getEleTy();
}
@@ -2861,6 +2888,40 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() {
// LoadOp
//===----------------------------------------------------------------------===//
+/// A load reads the slot exactly when its memref is the slot pointer.
+bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) {
+  mlir::Value address = getMemref();
+  return address == slot.ptr;
+}
+
+/// A load never stores to a slot.
+bool fir::LoadOp::storesTo(const mlir::MemorySlot & /*slot*/) {
+  return false;
+}
+
+/// LoadOp::storesTo() always returns false, so Mem2Reg must never ask a load
+/// for a stored value; reaching this is a programming error.
+mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot & /*slot*/,
+                                   mlir::OpBuilder & /*builder*/,
+                                   mlir::Value /*reachingDef*/,
+                                   const mlir::DataLayout & /*dataLayout*/) {
+  llvm_unreachable("getStored should not be called on LoadOp");
+}
+
+/// A load is removable when its single blocking use is the slot pointer, the
+/// pointer is this load's memref, and the loaded type is the slot's element
+/// type. No new blocking uses are produced.
+bool fir::LoadOp::canUsesBeRemoved(
+    const mlir::MemorySlot &slot,
+    const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::SmallVectorImpl<mlir::OpOperand *> & /*newBlockingUses*/,
+    const mlir::DataLayout & /*dataLayout*/) {
+  if (blockingUses.size() != 1)
+    return false;
+  mlir::OpOperand *onlyUse = *blockingUses.begin();
+  if (onlyUse->get() != slot.ptr)
+    return false;
+  if (getMemref() != slot.ptr)
+    return false;
+  return getType() == slot.elemType;
+}
+
+/// Rewire every user of the loaded value to the reaching definition, then
+/// ask Mem2Reg to delete the load.
+mlir::DeletionKind fir::LoadOp::removeBlockingUses(
+    const mlir::MemorySlot & /*slot*/,
+    const SmallPtrSetImpl<mlir::OpOperand *> & /*blockingUses*/,
+    mlir::OpBuilder & /*builder*/, mlir::Value reachingDefinition,
+    const mlir::DataLayout & /*dataLayout*/) {
+  mlir::Value loaded = getResult();
+  loaded.replaceAllUsesWith(reachingDefinition);
+  return mlir::DeletionKind::Delete;
+}
+
void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::Value refVal) {
if (!refVal) {
@@ -4256,6 +4317,39 @@ llvm::LogicalResult fir::SliceOp::verify() {
// StoreOp
//===----------------------------------------------------------------------===//
+/// A store never reads from a slot.
+bool fir::StoreOp::loadsFrom(const mlir::MemorySlot & /*slot*/) {
+  return false;
+}
+
+/// A store writes the slot exactly when its memref is the slot pointer.
+bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) {
+  mlir::Value address = getMemref();
+  return address == slot.ptr;
+}
+
+/// The value operand of this store is what it writes into the slot.
+mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot & /*slot*/,
+                                    mlir::OpBuilder & /*builder*/,
+                                    mlir::Value /*reachingDef*/,
+                                    const mlir::DataLayout & /*dataLayout*/) {
+  mlir::Value storedValue = getValue();
+  return storedValue;
+}
+
+/// A store is removable when its single blocking use is the slot pointer as
+/// the destination memref, it does not store the slot pointer itself, and
+/// the stored value has the slot's element type. No new blocking uses are
+/// produced.
+bool fir::StoreOp::canUsesBeRemoved(
+    const mlir::MemorySlot &slot,
+    const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::SmallVectorImpl<mlir::OpOperand *> & /*newBlockingUses*/,
+    const mlir::DataLayout & /*dataLayout*/) {
+  if (blockingUses.size() != 1)
+    return false;
+  mlir::OpOperand *onlyUse = *blockingUses.begin();
+  if (onlyUse->get() != slot.ptr || getMemref() != slot.ptr)
+    return false;
+  // Only stores INTO the slot qualify, never stores OF the slot pointer.
+  mlir::Value stored = getValue();
+  if (stored == slot.ptr)
+    return false;
+  return stored.getType() == slot.elemType;
+}
+
+/// A store has no results to rewire, so it is simply marked for deletion.
+mlir::DeletionKind fir::StoreOp::removeBlockingUses(
+    const mlir::MemorySlot & /*slot*/,
+    const SmallPtrSetImpl<mlir::OpOperand *> & /*blockingUses*/,
+    mlir::OpBuilder & /*builder*/, mlir::Value /*reachingDefinition*/,
+    const mlir::DataLayout & /*dataLayout*/) {
+  return mlir::DeletionKind::Delete;
+}
+
mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
return fir::dyn_cast_ptrEleTy(refType);
}
diff --git a/flang/test/Fir/mem2reg.mlir b/flang/test/Fir/mem2reg.mlir
new file mode 100644
index 0000000000000..25d114a55e1a4
--- /dev/null
+++ b/flang/test/Fir/mem2reg.mlir
@@ -0,0 +1,68 @@
+// RUN: fir-opt %s --allow-unregistered-dialect --mem2reg --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic() -> i32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK: return %[[CONSTANT_0]] : i32
+// CHECK: }
+func.func @basic() -> i32 {
+ // Straight-line store followed by a load of the same slot: the load is
+ // expected to be replaced by the stored constant, leaving no fir ops.
+ %0 = arith.constant 5 : i32
+ %1 = fir.alloca i32
+ fir.store %0 to %1 : !fir.ref<i32>
+ %2 = fir.load %1 : !fir.ref<i32>
+ return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @default_value() -> i32 {
+// CHECK: %[[UNDEFINED_0:.*]] = fir.undefined i32
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK: return %[[UNDEFINED_0]] : i32
+// CHECK: }
+func.func @default_value() -> i32 {
+ // The load comes before any store, so it must be replaced by the slot's
+ // default value (a fir.undefined of the element type), not by the
+ // constant stored afterwards.
+ %0 = arith.constant 5 : i32
+ %1 = fir.alloca i32
+ %2 = fir.load %1 : !fir.ref<i32>
+ fir.store %0 to %1 : !fir.ref<i32>
+ return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_float() -> f32 {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 5.200000e+00 : f32
+// CHECK: return %[[CONSTANT_0]] : f32
+// CHECK: }
+func.func @basic_float() -> f32 {
+ // Same store-then-load pattern as @basic, but with an f32 slot.
+ %0 = arith.constant 5.2 : f32
+ %1 = fir.alloca f32
+ fir.store %0 to %1 : !fir.ref<f32>
+ %2 = fir.load %1 : !fir.ref<f32>
+ return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cycle(
+// CHECK-SAME: %[[ARG0:.*]]: i64,
+// CHECK-SAME: %[[ARG1:.*]]: i1,
+// CHECK-SAME: %[[ARG2:.*]]: i64) {
+// CHECK: cf.cond_br %[[ARG1]], ^bb1(%[[ARG2]] : i64), ^bb2(%[[ARG2]] : i64)
+// CHECK: ^bb1(%[[VAL_0:.*]]: i64):
+// CHECK: "test.use"(%[[VAL_0]]) : (i64) -> ()
+// CHECK: cf.br ^bb2(%[[ARG0]] : i64)
+// CHECK: ^bb2(%[[VAL_1:.*]]: i64):
+// CHECK: cf.br ^bb1(%[[VAL_1]] : i64)
+// CHECK: }
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+ // Control-flow cycle between ^bb1 and ^bb2: the value loaded in ^bb1 is
+ // either the initial store (%arg2) or the store of %arg0 from a previous
+ // iteration, so promotion must thread it through block arguments.
+ %alloca = fir.alloca i64
+ fir.store %arg2 to %alloca : !fir.ref<i64>
+ cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+ %use = fir.load %alloca : !fir.ref<i64>
+ "test.use"(%use) : (i64) -> ()
+ fir.store %arg0 to %alloca : !fir.ref<i64>
+ cf.br ^bb2
+^bb2:
+ cf.br ^bb1
+}
More information about the flang-commits
mailing list