[flang-commits] [flang] [flang][fir] Support memref to memref fir.convert (PR #194954)

Ivan R. Ivanov via flang-commits flang-commits at lists.llvm.org
Thu Apr 30 03:18:25 PDT 2026


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/194954

>From 380a0c8d844e7ac7c23ec917362d7faf440e424e Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Wed, 29 Apr 2026 13:41:35 -0700
Subject: [PATCH 1/3] [flang][fir] Support memref to memref fir.convert

A memref-to-memref fir.convert can arise when a chain of fir.convert ops
between FIR pointer types is folded into a single memref-to-memref cast.
Lower it as if the source memref were first converted to a pointer (its
aligned pointer advanced by its offset) and that pointer were then
converted to the target memref.
---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp    |  65 +++++++-----
 flang/test/Fir/convert-memref-codegen.mlir | 117 ++++++++++++++++-----
 2 files changed, 129 insertions(+), 53 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 223cb2b2fb007..de023634be70f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -67,6 +67,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/DebugLog.h"
 
 namespace fir {
 #define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -934,30 +935,54 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
   llvm::LogicalResult
   matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = convert.getLoc();
+
     auto fromFirTy = convert.getValue().getType();
     auto toFirTy = convert.getRes().getType();
 
-    // Handle conversions between pointer-like values and memref descriptors.
-    // These are produced by FIR-to-MemRef lowering and represent descriptor
-    // conversion rather than pure value conversions.
-    if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy)) {
-      mlir::Location loc = convert.getLoc();
-      mlir::Value basePtr = adaptor.getValue();
-      assert(basePtr && "null base pointer");
+    auto toMemRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy);
+    auto fromMemRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy);
 
+    auto *firConv =
+        static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
+    assert(firConv && "expected non-null LLVMTypeConverter");
+
+    auto isStaticLayoutAndShape = [](mlir::MemRefType memRefTy) {
       auto [strides, offset] = memRefTy.getStridesAndOffset();
       bool hasStaticLayout =
           mlir::ShapedType::isStatic(offset) &&
           llvm::none_of(strides, mlir::ShapedType::isDynamic);
+      return hasStaticLayout && memRefTy.hasStaticShape();
+    };
+    auto getAlignedPtr = [&rewriter, &loc, &firConv](
+                             mlir::Value memRefVal, mlir::MemRefType memRefTy) {
+      auto alignedPtr =
+          mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 1);
+      auto offset =
+          mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 2);
+      mlir::Type elementType = firConv->convertType(memRefTy.getElementType());
+      auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
+                                             alignedPtr.getType(), elementType,
+                                             alignedPtr, offset.getResult());
+      return gepOp;
+    };
+
+    // Handle conversions into memref descriptors, either from pointer-like
+    // values (as produced by FIR-to-MemRef lowering) or from other memrefs.
+    // Both build a descriptor around a buffer pointer (no value conversion).
+    if (toMemRefTy) {
+      mlir::Value basePtr = adaptor.getValue();
+      assert(basePtr && "null base pointer");
 
-      auto *firConv =
-          static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
-      assert(firConv && "expected non-null LLVMTypeConverter");
+      // If the from type is also a memref we need to extract its buffer
+      // pointer.
+      if (fromMemRefTy)
+        basePtr = getAlignedPtr(basePtr, fromMemRefTy);
 
-      if (memRefTy.hasStaticShape() && hasStaticLayout) {
+      if (isStaticLayoutAndShape(toMemRefTy)) {
         // Static shape and layout: build a fully-populated descriptor.
         mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
-            rewriter, loc, *firConv, memRefTy, basePtr);
+            rewriter, loc, *firConv, toMemRefTy, basePtr);
         rewriter.replaceOp(convert, memrefDesc);
         return mlir::success();
       }
@@ -965,7 +990,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       // Dynamic shape or layout: create an LLVM memref descriptor and insert
       // the base pointer field, letting the rest of the fields be populated
       // by subsequent lowering.
-      mlir::Type llvmMemRefTy = firConv->convertType(memRefTy);
+      mlir::Type llvmMemRefTy = firConv->convertType(toMemRefTy);
       auto undef = mlir::LLVM::UndefOp::create(rewriter, loc, llvmMemRefTy);
       auto insert =
           mlir::LLVM::InsertValueOp::create(rewriter, loc, undef, basePtr, 1);
@@ -973,20 +998,11 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       return mlir::success();
     }
 
-    if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy)) {
+    if (fromMemRefTy) {
       // Legalize conversions *from* memref descriptors to pointer-like values
       // by extracting the underlying buffer pointer from the descriptor.
-      mlir::Location loc = convert.getLoc();
       mlir::Value base = adaptor.getValue();
-      auto alignedPtr =
-          mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 1);
-      auto offset = mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 2);
-      mlir::Type elementType =
-          this->getTypeConverter()->convertType(memRefTy.getElementType());
-      auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
-                                             alignedPtr.getType(), elementType,
-                                             alignedPtr, offset.getResult());
-      rewriter.replaceOp(convert, gepOp);
+      rewriter.replaceOp(convert, getAlignedPtr(base, fromMemRefTy));
       return mlir::success();
     }
 
@@ -999,7 +1015,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       return mlir::success();
     }
 
-    auto loc = convert.getLoc();
     auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
 
     if (mlir::isa<fir::RecordType>(toFirTy)) {
diff --git a/flang/test/Fir/convert-memref-codegen.mlir b/flang/test/Fir/convert-memref-codegen.mlir
index 5bc55647e76e1..0d8d3e6093253 100644
--- a/flang/test/Fir/convert-memref-codegen.mlir
+++ b/flang/test/Fir/convert-memref-codegen.mlir
@@ -1,36 +1,97 @@
-// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s -o - | FileCheck %s
+// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --split-input-file %s | FileCheck %s
 
 // This test ensures that the FIR CodeGen ConvertOpConversion
 // properly lowers fir.convert when either the source or the destination
 // type is a memref.
 
-module {
-  // CHECK-LABEL: llvm.func @memref_to_ref_convert(
-  // Reconstruct the memref descriptor from the expanded LLVM arguments.
-  // CHECK:         %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
-  //
-  // Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
-  // the buffer pointer from the descriptor.
-  // CHECK:         %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-  //
-  // The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
-  // CHECK:         %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
-  // CHECK:         %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
-  // CHECK:         %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
-  //
-  // CHECK-NOT:     fir.convert
-  func.func @memref_to_ref_convert(%arg0: memref<f32>) {
-    %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
-    %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
-    return
-  }
+// CHECK-LABEL: llvm.func @memref_to_ref_convert(
+// Reconstruct the memref descriptor from the expanded LLVM arguments.
+// CHECK:         %[[POISON0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[DESC0:.*]] = llvm.insertvalue %arg0, %[[POISON0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[DESC1:.*]] = llvm.insertvalue %arg1, %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[DESC:.*]] = llvm.insertvalue %arg2, %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// Lower the fir.convert from memref<f32> to !fir.ref<f32> by extracting
+// the buffer pointer from the descriptor.
+// CHECK:         %[[ALIGNED:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[OFF:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[BUF:.*]] = llvm.getelementptr %[[ALIGNED]][%[[OFF]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+//
+// The second fir.convert (from !fir.ref<f32> back to memref<f32>) lowering
+// CHECK:         %[[POISON1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[DESC2:.*]] = llvm.insertvalue %[[BUF]], %[[POISON1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[DESC3:.*]] = llvm.insertvalue %[[BUF]], %[[DESC2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:         %[[DESC4:.*]] = llvm.insertvalue %[[ZERO]], %[[DESC3]][2] : !llvm.struct<(ptr, ptr, i64)>
+//
+// CHECK-NOT:     fir.convert
+
+func.func @memref_to_ref_convert(%arg0: memref<f32>) {
+  %0 = fir.convert %arg0 : (memref<f32>) -> !fir.ref<f32>
+  %1 = fir.convert %0 : (!fir.ref<f32>) -> memref<f32>
+  return
 }
 
+// -----
+
+// CHECK-LABEL:   llvm.func @memref_to_memref_convert(
+// CHECK-SAME:      %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:      %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:      %[[ARG2:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64)> {
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_3]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           llvm.return %[[INSERTVALUE_5]] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:         }
+
+func.func @memref_to_memref_convert(%arg0: memref<f32>) -> memref<i1> {
+  %0 = fir.convert %arg0 : (memref<f32>) -> memref<i1>
+  return %0 : memref<i1>
+}
 
+// -----
+
+// CHECK-LABEL:   llvm.func @memref_to_memref_convert_shaped(
+// CHECK-SAME:      %[[ARG0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:      %[[ARG1:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:      %[[ARG2:[^:]*]]: i64,
+// CHECK-SAME:      %[[ARG3:[^:]*]]: i64,
+// CHECK-SAME:      %[[ARG4:[^:]*]]: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[ARG3]], %[[INSERTVALUE_2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[ARG4]], %[[INSERTVALUE_3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_4]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK:           %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[MLIR_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[GETELEMENTPTR_0]], %[[INSERTVALUE_5]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_2]], %[[INSERTVALUE_6]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_7]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_8]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_5:.*]] = llvm.mlir.constant(3 : index) : i64
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_5]], %[[INSERTVALUE_9]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_6:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_11:.*]] = llvm.insertvalue %[[MLIR_6]], %[[INSERTVALUE_10]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           llvm.return %[[INSERTVALUE_11]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:         }
+
+func.func @memref_to_memref_convert_shaped(%arg0: memref<3xf32>) -> memref<2x3xi1> {
+  %0 = fir.convert %arg0 : (memref<3xf32>) -> memref<2x3xi1>
+  return %0 : memref<2x3xi1>
+}

>From c1de9562ac4d17a3bf8282a187519c114ca95eed Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Wed, 29 Apr 2026 13:49:40 -0700
Subject: [PATCH 2/3] Rename getAlignedPtr lambda to getBufferPtr

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index de023634be70f..66b46645ce378 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -954,8 +954,8 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
           llvm::none_of(strides, mlir::ShapedType::isDynamic);
       return hasStaticLayout && memRefTy.hasStaticShape();
     };
-    auto getAlignedPtr = [&rewriter, &loc, &firConv](
-                             mlir::Value memRefVal, mlir::MemRefType memRefTy) {
+    auto getBufferPtr = [&rewriter, &loc, &firConv](mlir::Value memRefVal,
+                                                    mlir::MemRefType memRefTy) {
       auto alignedPtr =
           mlir::LLVM::ExtractValueOp::create(rewriter, loc, memRefVal, 1);
       auto offset =
@@ -977,7 +977,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       // If the from type is also a memref we need to extract its buffer
       // pointer.
       if (fromMemRefTy)
-        basePtr = getAlignedPtr(basePtr, fromMemRefTy);
+        basePtr = getBufferPtr(basePtr, fromMemRefTy);
 
       if (isStaticLayoutAndShape(toMemRefTy)) {
         // Static shape and layout: build a fully-populated descriptor.
@@ -1002,7 +1002,7 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       // Legalize conversions *from* memref descriptors to pointer-like values
       // by extracting the underlying buffer pointer from the descriptor.
       mlir::Value base = adaptor.getValue();
-      rewriter.replaceOp(convert, getAlignedPtr(base, fromMemRefTy));
+      rewriter.replaceOp(convert, getBufferPtr(base, fromMemRefTy));
       return mlir::success();
     }
 

>From a42bd83e1cda2366820f65363ec1f2624882f04b Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <iivanov at nvidia.com>
Date: Thu, 30 Apr 2026 03:12:37 -0700
Subject: [PATCH 3/3] Inline the static shape/layout check for the target memref

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 66b46645ce378..4087e146a7a93 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -947,13 +947,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
         static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
     assert(firConv && "expected non-null LLVMTypeConverter");
 
-    auto isStaticLayoutAndShape = [](mlir::MemRefType memRefTy) {
-      auto [strides, offset] = memRefTy.getStridesAndOffset();
-      bool hasStaticLayout =
-          mlir::ShapedType::isStatic(offset) &&
-          llvm::none_of(strides, mlir::ShapedType::isDynamic);
-      return hasStaticLayout && memRefTy.hasStaticShape();
-    };
     auto getBufferPtr = [&rewriter, &loc, &firConv](mlir::Value memRefVal,
                                                     mlir::MemRefType memRefTy) {
       auto alignedPtr =
@@ -979,7 +972,12 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
       if (fromMemRefTy)
         basePtr = getBufferPtr(basePtr, fromMemRefTy);
 
-      if (isStaticLayoutAndShape(toMemRefTy)) {
+      auto [strides, offset] = toMemRefTy.getStridesAndOffset();
+      bool hasStaticLayout =
+          mlir::ShapedType::isStatic(offset) &&
+          llvm::none_of(strides, mlir::ShapedType::isDynamic);
+
+      if (toMemRefTy.hasStaticShape() && hasStaticLayout) {
         // Static shape and layout: build a fully-populated descriptor.
         mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
             rewriter, loc, *firConv, toMemRefTy, basePtr);



More information about the flang-commits mailing list