[flang-commits] [flang] [flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (PR #188328)
via flang-commits
flang-commits at lists.llvm.org
Tue Mar 24 12:48:19 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Susan Tan (ス-ザン タン) (SusanTan)
<details>
<summary>Changes</summary>
Use fir.bitcast in FIR-to-MemRef casts so bit patterns are preserved (e.g. TRANSFER), while keeping fir.convert for memref/reference marshaling and non-bitcast-compatible cases.
---
Full diff: https://github.com/llvm/llvm-project/pull/188328.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/Transforms/FIRToMemRef.cpp (+47-9)
- (modified) flang/test/Transforms/FIRToMemRef/logical.mlir (+2-2)
``````````diff
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 447ee9c35f816..3b0b4bc007e61 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
originalValue);
}
+/// Decide whether a scalar crossing the FIR <-> std type boundary should be
+/// reinterpreted with fir.bitcast (bit pattern preserved) instead of being
+/// converted with fir.convert.
+///
+/// Returns true only when all of the following hold:
+///   - the two types differ;
+///   - exactly one of them is a std type (per fir::isa_std_type);
+///   - both are bitcast-compatible scalars (integer, float, fir.logical, or a
+///     single-character fir.char); and
+///   - their bit widths agree wherever both widths can be determined.
+/// Every other case returns false so the caller falls back to fir.convert.
+static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) {
+  auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
+    if (mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(ty))
+      return true;
+    // Only single-character CHARACTER values qualify; any other length falls
+    // back to fir.convert.
+    auto charTy = mlir::dyn_cast<fir::CharacterType>(ty);
+    return charTy && charTy.getLen() == fir::CharacterType::singleton();
+  };
+  // Bit width of a scalar, when it can be determined. The FIR-side widths
+  // (LOGICAL / CHARACTER) must be computable too: every pair that reaches the
+  // width check involves one of them, so without these cases the comparison
+  // below would be skipped for every pair and mismatched widths would be
+  // bitcast.
+  // NOTE(review): kind N is assumed to occupy N bytes (the default kind
+  // mapping); a non-default fir::KindMapping could differ -- confirm, or
+  // thread the KindMapping through if that matters here.
+  auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned> {
+    if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
+      return intTy.getWidth();
+    if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
+      return floatTy.getWidth();
+    if (auto logicalTy = mlir::dyn_cast<fir::LogicalType>(ty))
+      return logicalTy.getFKind() * 8u;
+    if (auto charTy = mlir::dyn_cast<fir::CharacterType>(ty))
+      return charTy.getFKind() * 8u;
+    return std::nullopt;
+  };
+
+  if (fromTy == toTy)
+    return false;
+  // Only reinterpret across the FIR/std boundary; std<->std and FIR<->FIR
+  // pairs keep using fir.convert.
+  if (fir::isa_std_type(fromTy) == fir::isa_std_type(toTy))
+    return false;
+  if (!isBitcastCompatibleScalarType(fromTy) ||
+      !isBitcastCompatibleScalarType(toTy))
+    return false;
+  // Never bitcast between types of provably different widths.
+  const std::optional<unsigned> fromBits = getKnownScalarBitWidth(fromTy);
+  const std::optional<unsigned> toBits = getKnownScalarBitWidth(toTy);
+  return !fromBits || !toBits || *fromBits == *toBits;
+}
+
+/// Materialize `value` as type `toTy` at `loc`.
+///
+/// Emits fir.bitcast (bit-pattern reinterpretation) when
+/// shouldUseBoundaryBitcast approves the source/target pair, and fir.convert
+/// otherwise. Returns the result of the newly created operation.
+static mlir::Value createTypeConversion(PatternRewriter &rewriter,
+                                        mlir::Location loc, mlir::Type toTy,
+                                        mlir::Value value) {
+  const bool reinterpretBits =
+      shouldUseBoundaryBitcast(value.getType(), toTy);
+  if (!reinterpretBits)
+    return fir::ConvertOp::create(rewriter, loc, toTy, value);
+  return fir::BitcastOp::create(rewriter, loc, toTy, value);
+}
+
FailureOr<SmallVector<Value>>
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
PatternRewriter &rewriter, Value converted,
@@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
loadOp.dump(); assert(succeeded(verify(loadOp))));
- if (isa<fir::LogicalType>(originalType)) {
- Value logicalVal =
- fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
- loadOp.getResult().replaceAllUsesExcept(logicalVal,
- logicalVal.getDefiningOp());
+ if (loadOp.getType() != originalType) {
+ Value castVal =
+ createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
+ loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
}
if (!isa<fir::LogicalType>(originalType))
@@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
Value value = store.getValue();
rewriter.setInsertionPointAfter(store);
- if (isa<fir::LogicalType>(value.getType())) {
- Type convertedType = typeConverter.convertType(value.getType());
+ Type convertedType = typeConverter.convertType(value.getType());
+ if (convertedType != value.getType())
value =
- fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
- }
+ createTypeConversion(rewriter, store.getLoc(), convertedType, value);
Attribute attr = (store.getOperation())->getAttr("tbaa");
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
diff --git a/flang/test/Transforms/FIRToMemRef/logical.mlir b/flang/test/Transforms/FIRToMemRef/logical.mlir
index 75a9fac3e1e45..948b8dcb2ae6e 100644
--- a/flang/test/Transforms/FIRToMemRef/logical.mlir
+++ b/flang/test/Transforms/FIRToMemRef/logical.mlir
@@ -4,7 +4,7 @@
// CHECK-NEXT: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
// CHECK-NEXT: [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref<i32>
-// CHECK-NEXT: fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
+// CHECK-NEXT: fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4>
func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
%0 = fir.undefined !fir.dscope
%1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
@@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
// CHECK: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4>
// CHECK-NEXT: [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
-// CHECK-NEXT: [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32
+// CHECK-NEXT: [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32
// CHECK-NEXT: memref.store [[INT]], [[CONVERT1]][] : memref<i32>
func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
%true = arith.constant true
``````````
</details>
https://github.com/llvm/llvm-project/pull/188328
More information about the flang-commits
mailing list