[flang-commits] [flang] [flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (PR #188328)

via flang-commits flang-commits at lists.llvm.org
Tue Mar 24 12:48:19 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Susan Tan (ス-ザン タン) (SusanTan)

<details>
<summary>Changes</summary>

Use fir.bitcast for the scalar value conversions in FIRToMemRef load/store rewriting so that bit patterns are preserved (e.g. for TRANSFER), while keeping fir.convert for memref/reference marshaling and for cases that are not bitcast-compatible.

---
Full diff: https://github.com/llvm/llvm-project/pull/188328.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/FIRToMemRef.cpp (+47-9) 
- (modified) flang/test/Transforms/FIRToMemRef/logical.mlir (+2-2) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 447ee9c35f816..3b0b4bc007e61 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
                                     originalValue);
 }
 
+static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) {
+  auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
+    return mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(
+               ty) ||
+           (mlir::isa<fir::CharacterType>(ty) &&
+            mlir::cast<fir::CharacterType>(ty).getLen() ==
+                fir::CharacterType::singleton());
+  };
+  auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned> {
+    if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
+      return intTy.getWidth();
+    if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
+      return floatTy.getWidth();
+    return std::nullopt;
+  };
+
+  if (fromTy == toTy)
+    return false;
+  const bool fromStd = fir::isa_std_type(fromTy);
+  const bool toStd = fir::isa_std_type(toTy);
+  if (fromStd == toStd)
+    return false;
+  if (!isBitcastCompatibleScalarType(fromTy) ||
+      !isBitcastCompatibleScalarType(toTy))
+    return false;
+  auto fromBits = getKnownScalarBitWidth(fromTy);
+  auto toBits = getKnownScalarBitWidth(toTy);
+  if (fromBits && toBits && *fromBits != *toBits)
+    return false;
+  return true;
+}
+
+static mlir::Value createTypeConversion(PatternRewriter &rewriter,
+                                        mlir::Location loc, mlir::Type toTy,
+                                        mlir::Value value) {
+  if (shouldUseBoundaryBitcast(value.getType(), toTy))
+    return fir::BitcastOp::create(rewriter, loc, toTy, value);
+  return fir::ConvertOp::create(rewriter, loc, toTy, value);
+}
+
 FailureOr<SmallVector<Value>>
 FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
                               PatternRewriter &rewriter, Value converted,
@@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
   LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
              loadOp.dump(); assert(succeeded(verify(loadOp))));
 
-  if (isa<fir::LogicalType>(originalType)) {
-    Value logicalVal =
-        fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
-    loadOp.getResult().replaceAllUsesExcept(logicalVal,
-                                            logicalVal.getDefiningOp());
+  if (loadOp.getType() != originalType) {
+    Value castVal =
+        createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
+    loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
   }
 
   if (!isa<fir::LogicalType>(originalType))
@@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
   Value value = store.getValue();
   rewriter.setInsertionPointAfter(store);
 
-  if (isa<fir::LogicalType>(value.getType())) {
-    Type convertedType = typeConverter.convertType(value.getType());
+  Type convertedType = typeConverter.convertType(value.getType());
+  if (convertedType != value.getType())
     value =
-        fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
-  }
+        createTypeConversion(rewriter, store.getLoc(), convertedType, value);
 
   Attribute attr = (store.getOperation())->getAttr("tbaa");
   memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/flang/test/Transforms/FIRToMemRef/logical.mlir b/flang/test/Transforms/FIRToMemRef/logical.mlir
index 75a9fac3e1e45..948b8dcb2ae6e 100644
--- a/flang/test/Transforms/FIRToMemRef/logical.mlir
+++ b/flang/test/Transforms/FIRToMemRef/logical.mlir
@@ -4,7 +4,7 @@
 // CHECK-NEXT:  [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
 // CHECK-NEXT:  [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
 // CHECK-NEXT:  [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref<i32>
-// CHECK-NEXT:  fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
+// CHECK-NEXT:  fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4>
 func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
   %0 = fir.undefined !fir.dscope
   %1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
@@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
 // CHECK:       [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
 // CHECK-NEXT:  [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4>
 // CHECK-NEXT:  [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
-// CHECK-NEXT:  [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32
+// CHECK-NEXT:  [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32
 // CHECK-NEXT:  memref.store [[INT]], [[CONVERT1]][] : memref<i32>
 func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
   %true = arith.constant true

``````````

</details>


https://github.com/llvm/llvm-project/pull/188328


More information about the flang-commits mailing list