[flang-commits] [flang] [flang][FIR][Mem2Reg] Add support for FIR. (PR #172808)

Ming Yan via flang-commits flang-commits at lists.llvm.org
Wed Dec 17 23:30:22 PST 2025


https://github.com/NexMing created https://github.com/llvm/llvm-project/pull/172808

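This patch implements the Mem2Reg interfaces for FIR (PromotableAllocationOpInterface on fir.alloca, PromotableMemOpInterface on fir.load and fir.store) and registers the mem2reg pass for the Fortran tools.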

>From 0f73278dc23e2ab8fba666fbebcedbde45289c45 Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Thu, 18 Dec 2025 14:55:08 +0800
Subject: [PATCH] [flang][FIR][Mem2Reg] Add support for FIR. This patch
 implements Mem2Reg interfaces for FIR.

---
 .../include/flang/Optimizer/Dialect/FIROps.td | 20 +++-
 .../include/flang/Optimizer/Support/InitFIR.h |  1 +
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 94 +++++++++++++++++++
 flang/test/Fir/mem2reg.mlir                   | 68 ++++++++++++++
 4 files changed, 178 insertions(+), 5 deletions(-)
 create mode 100644 flang/test/Fir/mem2reg.mlir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index cfce9fca504ec..7bfb304b973f5 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -17,6 +17,7 @@
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
@@ -80,7 +81,10 @@ def AnyRefOfConstantSizeAggregateType : TypeConstraint<
 // Memory SSA operations
 //===----------------------------------------------------------------------===//
 
-def fir_AllocaOp : fir_Op<"alloca", [AttrSizedOperandSegments]> {
+def fir_AllocaOp : fir_Op<"alloca", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<PromotableAllocationOpInterface>
+]> {
   let summary = "allocate storage for a temporary on the stack given a type";
   let description = [{
     This primitive operation is used to allocate an object on the stack.  A
@@ -288,8 +292,11 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
   let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
 }
 
-def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
-                                          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_LoadOp : fir_OneResultOp<"load", [
+  FirAliasTagOpInterface,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
   let summary = "load a value from a memory reference";
   let description = [{
     Load a value from a memory reference into an ssa-value (virtual register).
@@ -319,8 +326,11 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
   }];
 }
 
-def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
-                                   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+def fir_StoreOp : fir_Op<"store", [
+  FirAliasTagOpInterface,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<PromotableMemOpInterface>
+]> {
   let summary = "store an SSA-value to a memory location";
 
   let description = [{
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 67e9287ddad4f..41a979f97aece 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -129,6 +129,7 @@ inline void registerMLIRPassesForFortranTools() {
   mlir::affine::registerAffineLoopTilingPass();
   mlir::affine::registerAffineDataCopyGenerationPass();
 
+  mlir::registerMem2RegPass();
   mlir::registerLowerAffinePass();
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4e797d651cb7a..a004f7d9697ac 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -186,6 +186,44 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
   return fir::ReferenceType::get(intype);
 }
 
+/// Returns the single promotable slot of this alloca: the whole allocated
+/// object, typed by the allocated type. Allocas with LEN parameters or shape
+/// operands are not promoted yet.
+llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
+  // TODO: support promotion of allocas with LEN params or shape operands
+  if (hasLenParams() || hasShapeOperands())
+    return {};
+
+  return {mlir::MemorySlot{getResult(), getAllocatedType()}};
+}
+
+/// Value seen by loads not preceded by any store: a `fir.undefined` of the
+/// slot type.
+mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot,
+                                           mlir::OpBuilder &builder) {
+  return fir::UndefOp::create(builder, getLoc(), slot.elemType);
+}
+
+/// No FIR-specific bookkeeping is needed for the block arguments Mem2Reg
+/// inserts at control-flow merge points.
+void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot,
+                                        mlir::BlockArgument argument,
+                                        mlir::OpBuilder &builder) {}
+
+/// Hook run once promotion of the slot is complete: erases the alloca, and
+/// the default value if nothing ended up using it. Returns std::nullopt since
+/// no replacement allocation is created.
+std::optional<mlir::PromotableAllocationOpInterface>
+fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot,
+                                       mlir::Value defaultValue,
+                                       mlir::OpBuilder &builder) {
+  // `defaultValue` is null if it was never materialized.
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  this->erase();
+  return std::nullopt;
+}
+
 mlir::Type fir::AllocaOp::getAllocatedType() {
   return mlir::cast<fir::ReferenceType>(getType()).getEleTy();
 }
@@ -2861,6 +2888,49 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+/// A load reads the slot iff its memory reference is the slot pointer.
+bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) {
+  return getMemref() == slot.ptr;
+}
+
+/// Loads never write to memory.
+bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; }
+
+mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot,
+                                   mlir::OpBuilder &builder,
+                                   mlir::Value reachingDef,
+                                   const mlir::DataLayout &dataLayout) {
+  // Mem2Reg only calls getStored() on ops for which storesTo() is true.
+  llvm_unreachable("getStored should not be called on LoadOp");
+}
+
+/// The load is removable only if its single blocking use is the memref
+/// operand naming the slot and the loaded type is exactly the slot type.
+bool fir::LoadOp::canUsesBeRemoved(
+    const mlir::MemorySlot &slot,
+    const llvm::SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
+    const mlir::DataLayout &dataLayout) {
+  if (blockingUses.size() != 1)
+    return false;
+  mlir::Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemref() == slot.ptr &&
+         getType() == slot.elemType;
+}
+
+/// Replaces every use of the loaded value with the reaching definition and
+/// asks Mem2Reg to delete the load.
+mlir::DeletionKind fir::LoadOp::removeBlockingUses(
+    const mlir::MemorySlot &slot,
+    const llvm::SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+    const mlir::DataLayout &dataLayout) {
+  // canUsesBeRemoved() required getType() == slot.elemType, so the reaching
+  // definition can be forwarded without a conversion.
+  getResult().replaceAllUsesWith(reachingDefinition);
+  return mlir::DeletionKind::Delete;
+}
+
 void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::Value refVal) {
   if (!refVal) {
@@ -4256,6 +4317,47 @@ llvm::LogicalResult fir::SliceOp::verify() {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
+/// Stores never read from memory.
+bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; }
+
+/// A store writes the slot iff its memory reference is the slot pointer.
+bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) {
+  return getMemref() == slot.ptr;
+}
+
+/// The stored value becomes the new reaching definition of the slot.
+mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot,
+                                    mlir::OpBuilder &builder,
+                                    mlir::Value reachingDef,
+                                    const mlir::DataLayout &dataLayout) {
+  return getValue();
+}
+
+/// The store is removable only if its single blocking use is the memref
+/// operand naming the slot, the stored value is not the slot address itself,
+/// and the stored type is exactly the slot type.
+bool fir::StoreOp::canUsesBeRemoved(
+    const mlir::MemorySlot &slot,
+    const llvm::SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
+    const mlir::DataLayout &dataLayout) {
+  if (blockingUses.size() != 1)
+    return false;
+  mlir::Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemref() == slot.ptr &&
+         getValue() != slot.ptr && slot.elemType == getValue().getType();
+}
+
+/// The stored value reaches later loads through getStored(), so the store
+/// itself is simply deleted.
+mlir::DeletionKind fir::StoreOp::removeBlockingUses(
+    const mlir::MemorySlot &slot,
+    const llvm::SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+    mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+    const mlir::DataLayout &dataLayout) {
+  return mlir::DeletionKind::Delete;
+}
+
 mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
   return fir::dyn_cast_ptrEleTy(refType);
 }
diff --git a/flang/test/Fir/mem2reg.mlir b/flang/test/Fir/mem2reg.mlir
new file mode 100644
index 0000000000000..25d114a55e1a4
--- /dev/null
+++ b/flang/test/Fir/mem2reg.mlir
@@ -0,0 +1,88 @@
+// RUN: fir-opt %s --allow-unregistered-dialect --mem2reg --split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func @basic() -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK:           return %[[CONSTANT_0]] : i32
+// CHECK:         }
+func.func @basic() -> i32 {
+  %0 = arith.constant 5 : i32
+  %1 = fir.alloca i32
+  fir.store %0 to %1 : !fir.ref<i32>
+  %2 = fir.load %1 : !fir.ref<i32>
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @default_value() -> i32 {
+// CHECK:           %[[UNDEFINED_0:.*]] = fir.undefined i32
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK:           return %[[UNDEFINED_0]] : i32
+// CHECK:         }
+func.func @default_value() -> i32 {
+  %0 = arith.constant 5 : i32
+  %1 = fir.alloca i32
+  %2 = fir.load %1 : !fir.ref<i32>
+  fir.store %0 to %1 : !fir.ref<i32>
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @basic_float() -> f32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 5.200000e+00 : f32
+// CHECK:           return %[[CONSTANT_0]] : f32
+// CHECK:         }
+func.func @basic_float() -> f32 {
+  %0 = arith.constant 5.2 : f32
+  %1 = fir.alloca f32
+  fir.store %0 to %1 : !fir.ref<f32>
+  %2 = fir.load %1 : !fir.ref<f32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @cycle(
+// CHECK-SAME:                     %[[ARG0:.*]]: i64,
+// CHECK-SAME:                     %[[ARG1:.*]]: i1,
+// CHECK-SAME:                     %[[ARG2:.*]]: i64) {
+// CHECK:           cf.cond_br %[[ARG1]], ^bb1(%[[ARG2]] : i64), ^bb2(%[[ARG2]] : i64)
+// CHECK:         ^bb1(%[[VAL_0:.*]]: i64):
+// CHECK:           "test.use"(%[[VAL_0]]) : (i64) -> ()
+// CHECK:           cf.br ^bb2(%[[ARG0]] : i64)
+// CHECK:         ^bb2(%[[VAL_1:.*]]: i64):
+// CHECK:           cf.br ^bb1(%[[VAL_1]] : i64)
+// CHECK:         }
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+  %alloca = fir.alloca i64
+  fir.store %arg2 to %alloca : !fir.ref<i64>
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %use = fir.load %alloca : !fir.ref<i64>
+  "test.use"(%use) : (i64) -> ()
+  fir.store %arg0 to %alloca : !fir.ref<i64>
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+}
+
+// -----
+
+// An alloca whose address escapes to an op without Mem2Reg support must stay.
+// CHECK-LABEL:   func.func @escaping_use() -> i32 {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 5 : i32
+// CHECK:           %[[ALLOCA_0:.*]] = fir.alloca i32
+// CHECK:           fir.store %[[CONSTANT_0]] to %[[ALLOCA_0]] : !fir.ref<i32>
+// CHECK:           "test.use"(%[[ALLOCA_0]]) : (!fir.ref<i32>) -> ()
+// CHECK:           %[[LOAD_0:.*]] = fir.load %[[ALLOCA_0]] : !fir.ref<i32>
+// CHECK:           return %[[LOAD_0]] : i32
+// CHECK:         }
+func.func @escaping_use() -> i32 {
+  %0 = arith.constant 5 : i32
+  %1 = fir.alloca i32
+  fir.store %0 to %1 : !fir.ref<i32>
+  "test.use"(%1) : (!fir.ref<i32>) -> ()
+  %2 = fir.load %1 : !fir.ref<i32>
+  return %2 : i32
+}



More information about the flang-commits mailing list