[flang-commits] [flang] [flang] add an optimization to remove fir.convert usage in FIRToMemRef (PR #187721)

via flang-commits flang-commits at lists.llvm.org
Fri Mar 20 08:46:57 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-flang-fir-hlfir

Author: Susan Tan (ス-ザン タン) (SusanTan)

<details>
<summary>Changes</summary>

Add a peephole optimization to FIRToMemRef that forwards a raw memref.load result directly into memref.store when its type already matches the memref element type, removing the redundant fir.convert round-trip (e.g. i32 -> !fir.logical<4> -> i32).
This keeps the loaded value from being normalized (e.g. to 1) in nested TRANSFER flows such as transfer(transfer(i, .true.), 0).
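The normalization can be made concrete with a small C++ model of the two conversions. This is a sketch under an assumption, not flang code: following the comment in the patch, it assumes that converting i32 to !fir.logical<4> tests for nonzero and converting back yields 0 or 1. The helper names `toLogical`, `fromLogical`, `storedViaRoundTrip`, and `storedForwarded` are hypothetical.

```cpp
#include <cstdint>

// Hypothetical model (not flang code): i32 -> !fir.logical<4> is assumed to
// test for nonzero, and !fir.logical<4> -> i32 is assumed to produce 0 or 1,
// matching the "normalize the loaded value to 1" comment in the patch.
constexpr bool toLogical(std::int32_t raw) { return raw != 0; }
constexpr std::int32_t fromLogical(bool b) { return b ? 1 : 0; }

// Value stored without the peephole: load -> convert -> convert -> store.
constexpr std::int32_t storedViaRoundTrip(std::int32_t loaded) {
  return fromLogical(toLogical(loaded));
}

// Value stored with the peephole: the memref.load result is forwarded as-is.
constexpr std::int32_t storedForwarded(std::int32_t loaded) { return loaded; }

// transfer(transfer(i, .true.), 0) should give back the bits of i;
// under this model only the forwarded path does.
static_assert(storedViaRoundTrip(42) == 1, "round-trip collapses 42 to 1");
static_assert(storedForwarded(42) == 42, "forwarding keeps the bit pattern");
```

With i = 42, the round-trip path stores 1 while the forwarded path stores 42, which is the bit-preserving result a nested TRANSFER expects.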

---
Full diff: https://github.com/llvm/llvm-project/pull/187721.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/FIRToMemRef.cpp (+15) 
- (modified) flang/test/Transforms/FIRToMemRef/logical.mlir (+19) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 90b4b01e30f62..d29c1d06503d7 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -1010,6 +1010,21 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
   Value value = store.getValue();
   rewriter.setInsertionPointAfter(store);
 
+  // Small local optimization that avoids the round-trip:
+  //   %25 = memref.load ... : memref<i32>
+  //   %26 = fir.convert %25 : (i32) -> !fir.logical<4>   // from load rewrite
+  //   %27 = fir.convert %26 : (!fir.logical<4>) -> i32   // from store rewrite
+  //   memref.store %27, ... : memref<i32>
+  // which would normalize the loaded value to 1 and break TRANSFER-like flows,
+  // e.g. transfer(transfer(i, .true.), 0).
+  if (auto to = value.getDefiningOp<fir::ConvertOp>()) {
+    Value raw = to.getValue();
+    if (auto memrefTy = dyn_cast<MemRefType>(converted.getType()))
+      if (raw.getType() == memrefTy.getElementType() &&
+          isa_and_nonnull<memref::LoadOp>(raw.getDefiningOp()))
+        value = raw;
+  }
+
   if (isa<fir::LogicalType>(value.getType())) {
     Type convertedType = typeConverter.convertType(value.getType());
     value =
diff --git a/flang/test/Transforms/FIRToMemRef/logical.mlir b/flang/test/Transforms/FIRToMemRef/logical.mlir
index 75a9fac3e1e45..1c23944a8d75b 100644
--- a/flang/test/Transforms/FIRToMemRef/logical.mlir
+++ b/flang/test/Transforms/FIRToMemRef/logical.mlir
@@ -28,3 +28,22 @@ func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
   fir.store %2 to %1 : !fir.ref<!fir.logical<4>>
   return
 }
+
+// CHECK-LABEL: func.func @store_loaded_logical
+// CHECK:       [[DUMMY:%[0-9]+]] = fir.undefined !fir.dscope
+// CHECK:       [[SRC_DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
+// CHECK:       [[DST_DECLARE:%[0-9]+]] = fir.declare %arg1 dummy_scope [[DUMMY]]
+// CHECK:       [[SRC_MEM:%[0-9]+]] = fir.convert [[SRC_DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
+// CHECK:       [[LOAD:%[0-9]+]] = memref.load [[SRC_MEM]][] : memref<i32>
+// CHECK:       [[TOLOGICAL:%[0-9]+]] = fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
+// CHECK:       [[DST_MEM:%[0-9]+]] = fir.convert [[DST_DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
+// CHECK-NOT:   fir.convert [[TOLOGICAL]] : (!fir.logical<4>) -> i32
+// CHECK:       memref.store [[LOAD]], [[DST_MEM]][] : memref<i32>
+func.func @store_loaded_logical(%arg0: !fir.ref<!fir.logical<4>>, %arg1: !fir.ref<!fir.logical<4>>) {
+  %0 = fir.undefined !fir.dscope
+  %1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "src"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
+  %2 = fir.declare %arg1 dummy_scope %0 {uniq_name = "dst"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
+  %3 = fir.load %1 : !fir.ref<!fir.logical<4>>
+  fir.store %3 to %2 : !fir.ref<!fir.logical<4>>
+  return
+}

``````````

</details>
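The matching conditions added to `rewriteStoreOp` can be illustrated with a toy C++ model. This is a hypothetical sketch, not the MLIR API: `Value`, `OpKind`, and `forwardRawLoad` are invented names, and types are plain strings. It checks the same three things as the patch: the stored value is produced by a fir.convert, the convert's operand has the memref element type, and that operand comes from a memref.load.

```cpp
#include <string>

// Hypothetical toy IR (not MLIR): each value records its type, the kind of
// op that defined it, and that op's single input, if any.
enum class OpKind { None, MemrefLoad, FirConvert };

struct Value {
  std::string type;    // e.g. "i32" or "!fir.logical<4>"
  OpKind definingOp;   // kind of the defining op
  const Value *input;  // operand of the defining op, or nullptr
};

// Mirrors the patch: if the stored value comes from a convert whose operand
// already has the memref element type and is produced by a memref.load,
// store that raw operand instead; otherwise keep the value unchanged.
const Value *forwardRawLoad(const Value *stored, const std::string &elemType) {
  if (stored->definingOp != OpKind::FirConvert || stored->input == nullptr)
    return stored;
  const Value *raw = stored->input;
  if (raw->type == elemType && raw->definingOp == OpKind::MemrefLoad)
    return raw;
  return stored;
}
```

Returning the original value when any check fails mirrors the patch, which leaves `value` untouched and falls through to the existing logical-type conversion.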


https://github.com/llvm/llvm-project/pull/187721

