[flang-commits] [flang] 8dfe85b - [flang][fir] Support memref to memref fir.convert (#194954)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 30 06:12:38 PDT 2026
Author: Ivan R. Ivanov
Date: 2026-04-30T13:12:32Z
New Revision: 8dfe85b1e782eeb5006720fa0c20d0e4b34dc787
URL: https://github.com/llvm/llvm-project/commit/8dfe85b1e782eeb5006720fa0c20d0e4b34dc787
DIFF: https://github.com/llvm/llvm-project/commit/8dfe85b1e782eeb5006720fa0c20d0e4b34dc787.diff
LOG: [flang][fir] Support memref to memref fir.convert (#194954)
fir.convert of memref to memref can potentially arise due to a chain of
fir.convert between fir pointer types which get collapsed into a memref
to memref cast. Handle this as if we first convert to a pointer and then
convert the pointer to a memref.
Added:
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Fir/convert-memref-codegen.mlir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 223cb2b2fb007..6ba69c93fe7eb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -934,30 +934,52 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
llvm::LogicalResult
matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = convert.getLoc();
+
auto fromFirTy = convert.getValue().getType();
auto toFirTy = convert.getRes().getType();
+ auto toMemRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy);
+ auto fromMemRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy);
+
+ auto *firConv =
+ static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
+ assert(firConv && "expected non-null LLVMTypeConverter");
+
+ auto getBufferPtr = [&rewriter, &loc, &firConv](mlir::Value memRefVal,
+ mlir::MemRefType memRefTy) {
+ auto alignedPtr =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 1);
+ auto offset =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 2);
+ mlir::Type elementType = firConv->convertType(memRefTy.getElementType());
+ auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
+ alignedPtr.getType(), elementType,
+ alignedPtr, offset.getResult());
+ return gepOp;
+ };
+
// Handle conversions between pointer-like values and memref descriptors.
// These are produced by FIR-to-MemRef lowering and represent descriptor
// conversion rather than pure value conversions.
- if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy)) {
- mlir::Location loc = convert.getLoc();
+ if (toMemRefTy) {
mlir::Value basePtr = adaptor.getValue();
assert(basePtr && "null base pointer");
- auto [strides, offset] = memRefTy.getStridesAndOffset();
+ // If the from type is also a memref we need to extract its buffer
+ // pointer.
+ if (fromMemRefTy)
+ basePtr = getBufferPtr(basePtr, fromMemRefTy);
+
+ auto [strides, offset] = toMemRefTy.getStridesAndOffset();
bool hasStaticLayout =
mlir::ShapedType::isStatic(offset) &&
llvm::none_of(strides, mlir::ShapedType::isDynamic);
- auto *firConv =
- static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
- assert(firConv && "expected non-null LLVMTypeConverter");
-
- if (memRefTy.hasStaticShape() && hasStaticLayout) {
+ if (toMemRefTy.hasStaticShape() && hasStaticLayout) {
// Static shape and layout: build a fully-populated descriptor.
mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
- rewriter, loc, *firConv, memRefTy, basePtr);
+ rewriter, loc, *firConv, toMemRefTy, basePtr);
rewriter.replaceOp(convert, memrefDesc);
return mlir::success();
}
@@ -965,7 +987,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// Dynamic shape or layout: create an LLVM memref descriptor and insert
// the base pointer field, letting the rest of the fields be populated
// by subsequent lowering.
- mlir::Type llvmMemRefTy = firConv->convertType(memRefTy);
+ mlir::Type llvmMemRefTy = firConv->convertType(toMemRefTy);
auto undef = mlir::LLVM::UndefOp::create(rewriter, loc, llvmMemRefTy);
auto insert =
mlir::LLVM::InsertValueOp::create(rewriter, loc, undef, basePtr, 1);
@@ -973,20 +995,11 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
return mlir::success();
}
- if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy)) {
+ if (fromMemRefTy) {
// Legalize conversions *from* memref descriptors to pointer-like values
// by extracting the underlying buffer pointer from the descriptor.
- mlir::Location loc = convert.getLoc();
mlir::Value base = adaptor.getValue();
- auto alignedPtr =
- mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 1);
- auto offset = mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 2);
- mlir::Type elementType =
- this->getTypeConverter()->convertType(memRefTy.getElementType());
- auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
- alignedPtr.getType(), elementType,
- alignedPtr, offset.getResult());
- rewriter.replaceOp(convert, gepOp);
+ rewriter.replaceOp(convert, getBufferPtr(base, fromMemRefTy));
return mlir::success();
}
@@ -999,7 +1012,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
return mlir::success();
}
- auto loc = convert.getLoc();
auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
if (mlir::isa<fir::RecordType>(toFirTy)) {
diff --git a/flang/test/Fir/convert-memref-codegen.mlir b/flang/test/Fir/convert-memref-codegen.mlir
index 5bc55647e76e1..0d8d3e6093253 100644
--- a/flang/test/Fir/convert-memref-codegen.mlir
+++ b/flang/test/Fir/convert-memref-codegen.mlir
@@ -1,36 +1,97 @@
-// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s -o - | FileCheck %s
+// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --split-input-file %s | FileCheck %s
// This test ensures that the FIR CodeGen ConvertOpConversion
// properly lowers fir.convert when either the source or the destination
// type is a memref.
-module {
- // CHECK-LABEL: llvm.func @memref_to_ref_convert(
- // Reconstruct the memref descriptor from the expanded LLVM arguments.
- // CHECK: %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
- //
- // Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
- // the buffer pointer from the descriptor.
- // CHECK: %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- //
- // The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
- // CHECK: %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
- //
- // CHECK-NOT: fir.convert
- func.func @memref_to_ref_convert(%arg0: memref<f32>) {
- %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
- %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
- return
- }
+// CHECK-LABEL: llvm.func @memref_to_ref_convert(
+// Reconstruct the memref descriptor from the expanded LLVM arguments.
+// CHECK: %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
+// the buffer pointer from the descriptor.
+// CHECK: %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
+// CHECK: %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// CHECK-NOT: fir.convert
+
+func.func @memref_to_ref_convert(%arg0: memref<f32>) {
+ %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
+ %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
+ return
}
+// -----
+
+// CHECK-LABEL: llvm.func @memref_to_memref_convert(
+// CHECK-SAME: %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG2:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64)> {
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_3]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: llvm.return %[[INSERTVALUE_5]] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: }
+
+func.func @memref_to_memref_convert(%arg0: memref<f32>) -> memref<i1> {
+ %0 = fir.convert %arg0 : (memref<f32>) -> memref<i1>
+ return %0 : memref<i1>
+}
+// -----
+
+// CHECK-LABEL: llvm.func @memref_to_memref_convert(
+// CHECK-SAME: %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG2:[^:]*]]: i64,
+// CHECK-SAME: %[[ARG3:[^:]*]]: i64,
+// CHECK-SAME: %[[ARG4:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[ARG3]], %[[INSERTVALUE_2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[ARG4]], %[[INSERTVALUE_3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_5]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_6]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_3:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_4:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_8]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_5:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_5]], %[[INSERTVALUE_9]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_6:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[INSERTVALUE_11:.*]] = llvm.insertvalue %[[MLIR_6]], %[[INSERTVALUE_10]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[INSERTVALUE_11]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+
+func.func @memref_to_memref_convert(%arg0: memref<3xf32>) -> memref<2x3xi1> {
+ %0 = fir.convert %arg0 : (memref<3xf32>) -> memref<2x3xi1>
+ return %0 : memref<2x3xi1>
+}
More information about the flang-commits mailing list