[flang-commits] [flang] [flang][fir] Support memref to memref fir.convert (PR #194954)
Ivan R. Ivanov via flang-commits
flang-commits at lists.llvm.org
Thu Apr 30 03:18:25 PDT 2026
https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/194954
>From 380a0c8d844e7ac7c23ec917362d7faf440e424e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Wed, 29 Apr 2026 13:41:35 -0700
Subject: [PATCH 1/3] [flang][fir] Support memref to memref fir.convert
A memref-to-memref fir.convert can arise when a chain of fir.convert ops
between FIR pointer types gets collapsed into a single memref-to-memref
cast. Handle this as if we first converted the memref to a pointer and
then converted that pointer to a memref.
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 65 +++++++-----
flang/test/Fir/convert-memref-codegen.mlir | 117 ++++++++++++++++-----
2 files changed, 129 insertions(+), 53 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 223cb2b2fb007..de023634be70f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -67,6 +67,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/DebugLog.h"
namespace fir {
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -934,30 +935,54 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
llvm::LogicalResult
matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = convert.getLoc();
+
auto fromFirTy = convert.getValue().getType();
auto toFirTy = convert.getRes().getType();
- // Handle conversions between pointer-like values and memref descriptors.
- // These are produced by FIR-to-MemRef lowering and represent descriptor
- // conversion rather than pure value conversions.
- if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy)) {
- mlir::Location loc = convert.getLoc();
- mlir::Value basePtr = adaptor.getValue();
- assert(basePtr && "null base pointer");
+ auto toMemRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy);
+ auto fromMemRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy);
+ auto *firConv =
+ static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
+ assert(firConv && "expected non-null LLVMTypeConverter");
+
+ auto isStaticLayoutAndShape = [](mlir::MemRefType memRefTy) {
auto [strides, offset] = memRefTy.getStridesAndOffset();
bool hasStaticLayout =
mlir::ShapedType::isStatic(offset) &&
llvm::none_of(strides, mlir::ShapedType::isDynamic);
+ return hasStaticLayout && memRefTy.hasStaticShape();
+ };
+ auto getAlignedPtr = [&rewriter, &loc, &firConv](
+ mlir::Value memRefVal, mlir::MemRefType memRefTy) {
+ auto alignedPtr =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 1);
+ auto offset =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 2);
+ mlir::Type elementType = firConv->convertType(memRefTy.getElementType());
+ auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
+ alignedPtr.getType(), elementType,
+ alignedPtr, offset.getResult());
+ return gepOp;
+ };
+
+ // Handle conversions between pointer-like values and memref descriptors.
+ // These are produced by FIR-to-MemRef lowering and represent descriptor
+ // conversion rather than pure value conversions.
+ if (toMemRefTy) {
+ mlir::Value basePtr = adaptor.getValue();
+ assert(basePtr && "null base pointer");
- auto *firConv =
- static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
- assert(firConv && "expected non-null LLVMTypeConverter");
+ // If the from type is also a memref we need to extract its buffer
+ // pointer.
+ if (fromMemRefTy)
+ basePtr = getAlignedPtr(basePtr, fromMemRefTy);
- if (memRefTy.hasStaticShape() && hasStaticLayout) {
+ if (isStaticLayoutAndShape(toMemRefTy)) {
// Static shape and layout: build a fully-populated descriptor.
mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
- rewriter, loc, *firConv, memRefTy, basePtr);
+ rewriter, loc, *firConv, toMemRefTy, basePtr);
rewriter.replaceOp(convert, memrefDesc);
return mlir::success();
}
@@ -965,7 +990,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// Dynamic shape or layout: create an LLVM memref descriptor and insert
// the base pointer field, letting the rest of the fields be populated
// by subsequent lowering.
- mlir::Type llvmMemRefTy = firConv->convertType(memRefTy);
+ mlir::Type llvmMemRefTy = firConv->convertType(toMemRefTy);
auto undef = mlir::LLVM::UndefOp::create(rewriter, loc, llvmMemRefTy);
auto insert =
mlir::LLVM::InsertValueOp::create(rewriter, loc, undef, basePtr, 1);
@@ -973,20 +998,11 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
return mlir::success();
}
- if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy)) {
+ if (fromMemRefTy) {
// Legalize conversions *from* memref descriptors to pointer-like values
// by extracting the underlying buffer pointer from the descriptor.
- mlir::Location loc = convert.getLoc();
mlir::Value base = adaptor.getValue();
- auto alignedPtr =
- mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 1);
- auto offset = mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 2);
- mlir::Type elementType =
- this->getTypeConverter()->convertType(memRefTy.getElementType());
- auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
- alignedPtr.getType(), elementType,
- alignedPtr, offset.getResult());
- rewriter.replaceOp(convert, gepOp);
+ rewriter.replaceOp(convert, getAlignedPtr(base, fromMemRefTy));
return mlir::success();
}
@@ -999,7 +1015,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
return mlir::success();
}
- auto loc = convert.getLoc();
auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
if (mlir::isa<fir::RecordType>(toFirTy)) {
diff --git a/flang/test/Fir/convert-memref-codegen.mlir b/flang/test/Fir/convert-memref-codegen.mlir
index 5bc55647e76e1..0d8d3e6093253 100644
--- a/flang/test/Fir/convert-memref-codegen.mlir
+++ b/flang/test/Fir/convert-memref-codegen.mlir
@@ -1,36 +1,97 @@
-// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s -o - | FileCheck %s
+// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --split-input-file %s | FileCheck %s
// This test ensures that the FIR CodeGen ConvertOpConversion
// properly lowers fir.convert when either the source or the destination
// type is a memref.
-module {
- // CHECK-LABEL: llvm.func @memref_to_ref_convert(
- // Reconstruct the memref descriptor from the expanded LLVM arguments.
- // CHECK: %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
- //
- // Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
- // the buffer pointer from the descriptor.
- // CHECK: %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
- //
- // The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
- // CHECK: %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
- // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
- // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
- //
- // CHECK-NOT: fir.convert
- func.func @memref_to_ref_convert(%arg0: memref<f32>) {
- %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
- %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
- return
- }
+// CHECK-LABEL: llvm.func @memref_to_ref_convert(
+// Reconstruct the memref descriptor from the expanded LLVM arguments.
+// CHECK: %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
+// the buffer pointer from the descriptor.
+// CHECK: %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
+// CHECK: %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// CHECK-NOT: fir.convert
+
+func.func @memref_to_ref_convert(%arg0: memref<f32>) {
+ %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
+ %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
+ return
}
+// -----
+
+// CHECK-LABEL: llvm.func @memref_to_memref_convert(
+// CHECK-SAME: %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG2:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64)> {
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_3]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: llvm.return %[[INSERTVALUE_5]] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: }
+
+func.func @memref_to_memref_convert(%arg0: memref<f32>) -> memref<i1> {
+ %0 = fir.convert %arg0 : (memref<f32>) -> memref<i1>
+ return %0 : memref<i1>
+}
+// -----
+
+// CHECK-LABEL: llvm.func @memref_to_memref_convert(
+// CHECK-SAME: %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG2:[^:]*]]: i64,
+// CHECK-SAME: %[[ARG3:[^:]*]]: i64,
+// CHECK-SAME: %[[ARG4:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[ARG3]], %[[INSERTVALUE_2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[ARG4]], %[[INSERTVALUE_3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_5]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_6]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_3:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK: %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_4:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_8]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_5:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK: %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_5]], %[[INSERTVALUE_9]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[MLIR_6:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[INSERTVALUE_11:.*]] = llvm.insertvalue %[[MLIR_6]], %[[INSERTVALUE_10]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: llvm.return %[[INSERTVALUE_11]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: }
+
+func.func @memref_to_memref_convert(%arg0: memref<3xf32>) -> memref<2x3xi1> {
+ %0 = fir.convert %arg0 : (memref<3xf32>) -> memref<2x3xi1>
+ return %0 : memref<2x3xi1>
+}
>From c1de9562ac4d17a3bf8282a187519c114ca95eed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Wed, 29 Apr 2026 13:49:40 -0700
Subject: [PATCH 2/3] func name
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index de023634be70f..66b46645ce378 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -954,8 +954,8 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
llvm::none_of(strides, mlir::ShapedType::isDynamic);
return hasStaticLayout && memRefTy.hasStaticShape();
};
- auto getAlignedPtr = [&rewriter, &loc, &firConv](
- mlir::Value memRefVal, mlir::MemRefType memRefTy) {
+ auto getBufferPtr = [&rewriter, &loc, &firConv](mlir::Value memRefVal,
+ mlir::MemRefType memRefTy) {
auto alignedPtr =
mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 1);
auto offset =
@@ -977,7 +977,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// If the from type is also a memref we need to extract its buffer
// pointer.
if (fromMemRefTy)
- basePtr = getAlignedPtr(basePtr, fromMemRefTy);
+ basePtr = getBufferPtr(basePtr, fromMemRefTy);
if (isStaticLayoutAndShape(toMemRefTy)) {
// Static shape and layout: build a fully-populated descriptor.
@@ -1002,7 +1002,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
// Legalize conversions *from* memref descriptors to pointer-like values
// by extracting the underlying buffer pointer from the descriptor.
mlir::Value base = adaptor.getValue();
- rewriter.replaceOp(convert, getAlignedPtr(base, fromMemRefTy));
+ rewriter.replaceOp(convert, getBufferPtr(base, fromMemRefTy));
return mlir::success();
}
>From a42bd83e1cda2366820f65363ec1f2624882f04b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Thu, 30 Apr 2026 03:12:37 -0700
Subject: [PATCH 3/3] fix
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 66b46645ce378..4087e146a7a93 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -947,13 +947,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
assert(firConv && "expected non-null LLVMTypeConverter");
- auto isStaticLayoutAndShape = [](mlir::MemRefType memRefTy) {
- auto [strides, offset] = memRefTy.getStridesAndOffset();
- bool hasStaticLayout =
- mlir::ShapedType::isStatic(offset) &&
- llvm::none_of(strides, mlir::ShapedType::isDynamic);
- return hasStaticLayout && memRefTy.hasStaticShape();
- };
auto getBufferPtr = [&rewriter, &loc, &firConv](mlir::Value memRefVal,
mlir::MemRefType memRefTy) {
auto alignedPtr =
@@ -979,7 +972,12 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
if (fromMemRefTy)
basePtr = getBufferPtr(basePtr, fromMemRefTy);
- if (isStaticLayoutAndShape(toMemRefTy)) {
+ auto [strides, offset] = toMemRefTy.getStridesAndOffset();
+ bool hasStaticLayout =
+ mlir::ShapedType::isStatic(offset) &&
+ llvm::none_of(strides, mlir::ShapedType::isDynamic);
+
+ if (toMemRefTy.hasStaticShape() && hasStaticLayout) {
// Static shape and layout: build a fully-populated descriptor.
mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
rewriter, loc, *firConv, toMemRefTy, basePtr);
More information about the flang-commits
mailing list