[flang-commits] [flang] [FIR] add a fir.shape_extents operation (PR #199361)

via flang-commits flang-commits at lists.llvm.org
Mon Jun 8 09:47:21 PDT 2026


https://github.com/yebinchon updated https://github.com/llvm/llvm-project/pull/199361

>From a89731080414d08fb317344fdf6becc950d094bf Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 09:32:01 -0700
Subject: [PATCH 1/6] [FIR] add a shape_extents operation that takes a
 fir.shape and unpacks it into n SSA values (one per dim)

---
 .../include/flang/Optimizer/Dialect/FIROps.td | 35 +++++++++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 69 +++++++++++++---
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp |  6 ++
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 45 +++++++++++
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 78 +++++++++++++++----
 flang/test/Fir/convert-to-llvm-invalid.fir    | 12 ---
 flang/test/Fir/shape-extents.mlir             | 42 ++++++++++
 7 files changed, 251 insertions(+), 36 deletions(-)
 create mode 100644 flang/test/Fir/shape-extents.mlir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index ec42855761bcb..3d0ac4398e7df 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2087,6 +2087,41 @@ def fir_ShapeOp : fir_Op<"shape", [Pure]> {
   let builders = [OpBuilder<(ins "mlir::ValueRange":$extents)>];
 }
 
+def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
+  let summary = "unpack a `!fir.shape` into per-dimension extent SSA values";
+
+  let description = [{
+    Takes a single abstract `!fir.shape<n>` value and yields `n` integer SSA
+    results, one per dimension extent, first dimension first. This is
+    intended for lowering when extent values are needed but the defining
+    `fir.shape` is not visible (for example when a shape is forwarded through
+    control flow as a block argument).
+
+    When the operand is the result of `fir.shape`, a canonicalization may fold
+    this operation to the original extent operands.
+
+    ```
+      %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+    ```
+  }];
+
+  let arguments = (ins fir_ShapeType:$shape);
+
+  let results = (outs Variadic<AnyIntegerType>:$extents);
+
+  let assemblyFormat = [{
+    $shape attr-dict `:` functional-type(operands, results)
+  }];
+  // Requires one result per dimension of the shape operand's rank.
+  let hasVerifier = 1;
+  // Folds away when the operand is produced directly by fir.shape.
+  let hasCanonicalizer = 1;
+  // No default builders; the custom one infers `index` results per dimension.
+  let skipDefaultBuilders = 1;
+
+  let builders = [OpBuilder<(ins "mlir::Value":$shape)>];
+}
+
 def fir_ShapeShiftOp : fir_Op<"shape_shift", [Pure]> {
 
   let summary = [{
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 6b1acba393170..523ca462dde78 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4577,8 +4577,57 @@ struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> {
   }
 };
 
-struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
-  using MustBeDeadConversion::MustBeDeadConversion;
+// Lowers fir.shape to an LLVM struct holding one i64 extent per dimension.
+struct ShapeOpConversion : public fir::FIROpConversion<fir::ShapeOp> {
+  using FIROpConversion::FIROpConversion;
+
+  llvm::LogicalResult
+  matchAndRewrite(fir::ShapeOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    if (op->use_empty()) { // a dead shape needs no struct
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+    auto loc = op.getLoc();
+    auto shapeTy = mlir::cast<fir::ShapeType>(op.getType());
+    mlir::Type llvmShapeTy = convertType(shapeTy);
+    mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+    mlir::Value structVal =
+        mlir::LLVM::UndefOp::create(rewriter, loc, llvmShapeTy);
+    for (auto [i, extent] : llvm::enumerate(adaptor.getExtents())) {
+      mlir::Value extentI64 =
+          integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
+      structVal =
+          mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal, extentI64, i);
+    }
+    rewriter.replaceOp(op, structVal);
+    return mlir::success();
+  }
+};
+
+struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsOp> {
+  using FIROpConversion::FIROpConversion;
+  // Unpacks the converted LLVM shape struct into the op's integer results.
+  llvm::LogicalResult
+  matchAndRewrite(fir::ShapeExtentsOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto shapeTy = mlir::cast<fir::ShapeType>(op.getShape().getType());
+    if (shapeTy.getRank() != op.getNumResults())
+      return mlir::failure(); // defensive: the verifier enforces this
+    mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+    mlir::Value llvmShape = adaptor.getShape(); // converted {i64 x rank} struct
+    llvm::SmallVector<mlir::Value> results;
+    for (unsigned i = 0; i < op.getNumResults(); ++i) { // one per dimension
+      mlir::Value extentI64 =
+          mlir::LLVM::ExtractValueOp::create(rewriter, loc, i64Ty, llvmShape, i);
+      mlir::Type resultTy = convertType(op.getExtents()[i].getType());
+      results.push_back(
+          integerCast(loc, rewriter, resultTy, extentI64, /*fold=*/true));
+    }
+    rewriter.replaceOp(op, results);
+    return mlir::success();
+  }
+};
 
 struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
@@ -4880,14 +4929,14 @@ void fir::populateFIRToLLVMConversionPatterns(
       LogicalOrOpConversion, MulcOpConversion, NegcOpConversion,
       NeqvOpConversion, NoReassocOpConversion, PrefetchOpConversion,
       SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
-      SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
-      ShiftOpConversion, SliceOpConversion, StoreOpConversion,
-      StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
-      TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
-      UndefOpConversion, UnreachableOpConversion, UseStmtOpConversion,
-      ModuleDebugImportsOpConversion, XArrayCoorOpConversion,
-      XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
-                                                                options);
+      SelectTypeOpConversion, ShapeOpConversion, ShapeExtentsOpConversion,
+      ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
+      StoreOpConversion, StringLitOpConversion, SubcOpConversion,
+      TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
+      UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
+      UseStmtOpConversion, ModuleDebugImportsOpConversion,
+      XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
+      ZeroOpConversion>(converter, options);
 
   // Patterns that are populated without a type converter do not trigger
   // target materializations for the operands of the root op.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 2fca4111e0980..0cbf0aa259219 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -123,6 +123,12 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
   addConversion([&](fir::SequenceType sequence) {
     return convertSequenceType(sequence);
   });
+  addConversion([&](fir::ShapeType shape) {
+    mlir::Type i64Ty = mlir::IntegerType::get(&getContext(), 64);
+    llvm::SmallVector<mlir::Type> members(shape.getRank(), i64Ty);
+    return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
+                                                  /*isPacked=*/false);
+  });
   addConversion([&](fir::TypeDescType tdesc) {
     return convertTypeDescType(tdesc.getContext());
   });
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 937f5c3f07e7d..c53d1b20456c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4798,6 +4798,51 @@ void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
   build(builder, result, type, extents);
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeExtentsOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldShapeExtentsOfShape
+    : public mlir::OpRewritePattern<fir::ShapeExtentsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  // Forward the extents of a defining fir.shape; assumes types already match.
+  mlir::LogicalResult
+  matchAndRewrite(fir::ShapeExtentsOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
+    if (!shapeOp)
+      return mlir::failure();
+    rewriter.replaceOp(op, shapeOp.getExtents());
+    return mlir::success();
+  }
+};
+} // namespace
+
+llvm::LogicalResult fir::ShapeExtentsOp::verify() {
+  auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getShape().getType());
+  if (!shapeTy) // defensive: ODS already restricts $shape to !fir.shape
+    return emitOpError("operand must be a !fir.shape type");
+  if (getNumResults() != shapeTy.getRank()) // one result per dimension
+    return emitOpError("number of results must match shape rank");
+  return mlir::success();
+}
+
+void fir::ShapeExtentsOp::build(mlir::OpBuilder &builder,
+                                mlir::OperationState &result,
+                                mlir::Value shape) {
+  auto shapeTy = mlir::cast<fir::ShapeType>(shape.getType());
+  mlir::Type indexTy = builder.getIndexType(); // every extent is an index
+  llvm::SmallVector<mlir::Type> resultTypes(shapeTy.getRank(), indexTy);
+  result.addTypes(resultTypes);
+  result.addOperands(shape);
+}
+
+void fir::ShapeExtentsOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+  patterns.add<FoldShapeExtentsOfShape>(context); // fold through fir.shape
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeShiftOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 2c27e34a1d5c2..3f2a049d8a89e 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -177,6 +177,13 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
 
   void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
 
+  /// Appends per-dimension extent SSA values for \p shapeVal to \p shapeVec,
+  /// inserting `fir.shape_extents` when no defining `fir.shape` is visible
+  /// (e.g. a block argument). Returns false if no extents can be recovered.
+  bool materializeShapeExtents(Value shapeVal, PatternRewriter &rewriter,
+                               Location loc,
+                               SmallVectorImpl<Value> &shapeVec) const;
+
   static fir::SliceOp getSliceOp(Value sliceVal) {
     return sliceVal ? sliceVal.getDefiningOp<fir::SliceOp>() : fir::SliceOp{};
   }
@@ -316,6 +323,36 @@ void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
   vec.append(shape.getExtents().begin(), shape.getExtents().end());
 }
 
+bool FIRToMemRef::materializeShapeExtents(
+    Value shapeVal, PatternRewriter &rewriter, Location loc,
+    SmallVectorImpl<Value> &shapeVec) const {
+  if (!shapeVal)
+    return false;
+  // Look through fir.convert to reach the producer of the shape value.
+  while (auto convertOp = shapeVal.getDefiningOp<fir::ConvertOp>())
+    shapeVal = convertOp.getOperand();
+  // Defined by fir.shape: reuse its extent operands directly.
+  if (auto shapeOp = shapeVal.getDefiningOp<fir::ShapeOp>()) {
+    shapeVec.append(shapeOp.getExtents().begin(), shapeOp.getExtents().end());
+    return true;
+  }
+  // Already unpacked by fir.shape_extents: reuse its results.
+  if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
+    shapeVec.append(extentsOp.getExtents().begin(),
+                   extentsOp.getExtents().end());
+    return true;
+  }
+  // Opaque !fir.shape (e.g. a block argument): emit fir.shape_extents.
+  if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+    auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
+    shapeVec.append(extentsOp.getExtents().begin(),
+                   extentsOp.getExtents().end());
+    return true;
+  }
+  // Not a !fir.shape value: no extents can be recovered here.
+  return false;
+}
+
 template <typename OpTy>
 void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
   if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
@@ -324,14 +361,14 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
     Value shapeVal = op.getShape();
 
     if (shapeVal) {
-      Operation *shapeValOp = shapeVal.getDefiningOp();
-
-      if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
-        populateShape(info.shapeVec, shapeOp);
-      } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
-        populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
-      } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
-        populateShift(info.shiftVec, shiftOp);
+      if (Operation *shapeValOp = shapeVal.getDefiningOp()) {
+        if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
+          populateShape(info.shapeVec, shapeOp);
+        } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
+          populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
+        } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
+          populateShift(info.shiftVec, shiftOp);
+        }
       }
     }
 
@@ -642,6 +679,20 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
     rewriter.setInsertionPointAfter(arrayCoorOp);
   }
 
+  SliceInfo sliceInfo;
+  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
+  if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
+    collectSliceInfoFrom(embox, sliceInfo);
+  else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
+    collectSliceInfoFrom(rebox, sliceInfo);
+
+  if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
+    rewriter.setInsertionPoint(arrayCoorOp);
+    (void)materializeShapeExtents(arrayCoorOp.getShape(), rewriter, loc,
+                                  sliceInfo.shapeVec);
+    rewriter.setInsertionPointAfter(arrayCoorOp);
+  }
+
   Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
   FailureOr<SmallVector<Value>> failureOrIndices =
       getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one);
@@ -662,12 +713,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
   // For complex projections, reinterpret memref<d0×...×complex<T>> as
   // memref<d0×...×2×T> and append the component index (0=re, 1=im) so that
   // each load/store touches exactly sizeof(T) bytes.
-  SliceInfo sliceInfo;
-  collectSliceInfoFrom(arrayCoorOp, sliceInfo);
-  if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
-    collectSliceInfoFrom(embox, sliceInfo);
-  else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
-    collectSliceInfoFrom(rebox, sliceInfo);
   auto srcTy = cast<MemRefType>((*converted).getType());
   if (sliceInfo.hasProjectedSlice) {
     if (auto complexTy = dyn_cast<mlir::ComplexType>(srcTy.getElementType())) {
@@ -729,6 +774,11 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
       (isDescriptor && sliceInfo.shiftVec.empty() && arrayCoorOp.getShape() &&
        arrayCoorOp.getSlice());
   if (descriptorOwnsLayout) {
+    // Plain `!fir.ref` without recoverable shape extents cannot use fir.box_*.
+    if (shapeVec.empty() && !sliceInfo.hasProjectedSlice && !isDescriptor &&
+        !isRebox)
+      return failure();
+
     // Projected slices carry their physical layout in the descriptor. Rebuild
     // the MemRef view from box metadata instead of from slice triplets.
     auto boxElementSize =
diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir
index 9f003e4eb7d59..bd22dc81f05ca 100644
--- a/flang/test/Fir/convert-to-llvm-invalid.fir
+++ b/flang/test/Fir/convert-to-llvm-invalid.fir
@@ -3,18 +3,6 @@
 // RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --verify-diagnostics %s
 
 
-// Test `fir.shape` conversion failure because the op has uses.
-
-func.func @shape_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
-  %c0 = arith.constant 1 : index
-  // expected-error at +1{{failed to legalize operation 'fir.shape'}}
-  %0 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2>
-  %1 = fir.array_coor %arg0(%0) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
-  return
-}
-
-// -----
-
 // Test `fir.slice` conversion failure because the op has uses.
 
 func.func @slice_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
diff --git a/flang/test/Fir/shape-extents.mlir b/flang/test/Fir/shape-extents.mlir
new file mode 100644
index 0000000000000..3d7c17bf2fcb6
--- /dev/null
+++ b/flang/test/Fir/shape-extents.mlir
@@ -0,0 +1,42 @@
+// RUN: fir-opt %s | FileCheck %s --check-prefix=PLAIN
+// RUN: fir-opt %s -canonicalize | FileCheck %s --check-prefix=CANON
+// Folding through a visible 1-D fir.shape removes the op under -canonicalize.
+// PLAIN-LABEL: func @fold_shape_extents
+// PLAIN:         fir.shape_extents
+// CANON-LABEL:   func @fold_shape_extents
+// CANON:         fir.fake_use %[[N:arg[0-9]+]] : index
+func.func @fold_shape_extents(%n : index) {
+  %sh = fir.shape %n : (index) -> !fir.shape<1>
+  %e = fir.shape_extents %sh : (!fir.shape<1>) -> (index)
+  fir.fake_use %e : index
+  return
+}
+// 2-D case: each result is replaced by the corresponding fir.shape operand.
+// PLAIN-LABEL: func @shape_extents_2d
+// PLAIN:         fir.shape_extents
+// CANON-LABEL:   func @shape_extents_2d
+// CANON:         fir.fake_use %[[N1:arg[0-9]+]], %[[N2:arg[0-9]+]] : index, index
+func.func @shape_extents_2d(%n1 : index, %n2 : index) {
+  %sh = fir.shape %n1, %n2 : (index, index) -> !fir.shape<2>
+  %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+  fir.fake_use %e0, %e1 : index, index
+  return
+}
+// Shape arrives as a block argument: nothing to fold, so the op survives.
+// PLAIN-LABEL: func @shape_extents_block_arg
+// PLAIN:         fir.shape_extents
+// CANON-LABEL:   func @shape_extents_block_arg
+// CANON:         fir.shape_extents
+func.func @shape_extents_block_arg(%pred : i1, %n1 : index, %n2 : index) {
+  cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+  %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+  %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+  %e = fir.shape_extents %phi : (!fir.shape<1>) -> (index)
+  fir.fake_use %e : index
+  return
+}

>From e1419420a5e115760315d772a5974277e11074f2 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 09:46:14 -0700
Subject: [PATCH 2/6] add forwarded-shape test

---
 .../FIRToMemRef/forwarded-shape.mlir          | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)
 create mode 100644 flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir

diff --git a/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir b/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir
new file mode 100644
index 0000000000000..b3e8cae4e00a3
--- /dev/null
+++ b/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir
@@ -0,0 +1,24 @@
+// RUN: fir-opt %s --fir-to-memref --allow-unregistered-dialect | FileCheck %s
+// fir.array_coor whose shape is a block argument (no visible fir.shape).
+func.func @forwarded_shape_store(%pred : i1, %n1 : index, %n2 : index,
+    %arg0: !fir.ref<!fir.array<?xi32>>) {
+  cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+  %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+  %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : i32
+  %elt = fir.array_coor %arg0(%phi) %c1
+      : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+  fir.store %c42 to %elt : !fir.ref<i32>
+  return
+}
+// The pass must unpack the forwarded shape via fir.shape_extents.
+// CHECK-LABEL: func.func @forwarded_shape_store
+// CHECK:       fir.shape_extents
+// CHECK:       memref.store
+// CHECK-NOT:   fir.array_coor

>From f93e2a15fc5abe3269497aa5f215d7f198c09496 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 10:16:39 -0700
Subject: [PATCH 3/6] formatting

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp        | 11 ++++++-----
 flang/lib/Optimizer/Transforms/FIRToMemRef.cpp |  7 ++++---
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 523ca462dde78..f71455aada159 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4597,15 +4597,16 @@ struct ShapeOpConversion : public fir::FIROpConversion<fir::ShapeOp> {
     for (auto [i, extent] : llvm::enumerate(adaptor.getExtents())) {
       mlir::Value extentI64 =
           integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
-      structVal =
-          mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal, extentI64, i);
+      structVal = mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal,
+                                                    extentI64, i);
     }
     rewriter.replaceOp(op, structVal);
     return mlir::success();
   }
 };
 
-struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsOp> {
+struct ShapeExtentsOpConversion
+    : public fir::FIROpConversion<fir::ShapeExtentsOp> {
   using FIROpConversion::FIROpConversion;
 
   llvm::LogicalResult
@@ -4619,8 +4620,8 @@ struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsO
     mlir::Value llvmShape = adaptor.getShape();
     llvm::SmallVector<mlir::Value> results;
     for (unsigned i = 0; i < op.getNumResults(); ++i) {
-      mlir::Value extentI64 =
-          mlir::LLVM::ExtractValueOp::create(rewriter, loc, i64Ty, llvmShape, i);
+      mlir::Value extentI64 = mlir::LLVM::ExtractValueOp::create(
+          rewriter, loc, i64Ty, llvmShape, i);
       mlir::Type resultTy = convertType(op.getExtents()[i].getType());
       results.push_back(
           integerCast(loc, rewriter, resultTy, extentI64, /*fold=*/true));
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 3f2a049d8a89e..c33a2103ab9e9 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -339,14 +339,14 @@ bool FIRToMemRef::materializeShapeExtents(
 
   if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
     shapeVec.append(extentsOp.getExtents().begin(),
-                   extentsOp.getExtents().end());
+                    extentsOp.getExtents().end());
     return true;
   }
 
   if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
     auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
     shapeVec.append(extentsOp.getExtents().begin(),
-                   extentsOp.getExtents().end());
+                    extentsOp.getExtents().end());
     return true;
   }
 
@@ -364,7 +364,8 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
       if (Operation *shapeValOp = shapeVal.getDefiningOp()) {
         if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
           populateShape(info.shapeVec, shapeOp);
-        } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
+        } else if (auto shapeShiftOp =
+                       dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
           populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
         } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
           populateShift(info.shiftVec, shiftOp);

>From ca5285dc5ed70c4803087f52d716683fe67c6c58 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Tue, 2 Jun 2026 10:08:09 -0700
Subject: [PATCH 4/6] insert type casts when fir.shape and fir.shape_extents
 types mismatch; fail when the shape cannot be recovered

---
 flang/lib/Optimizer/Dialect/FIROps.cpp         | 13 ++++++++++++-
 flang/lib/Optimizer/Transforms/FIRToMemRef.cpp | 15 ++++++++-------
 flang/test/Fir/shape-extents.mlir              | 15 +++++++++++++++
 3 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index c53d1b20456c7..fa48bfcc39cb4 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4813,7 +4813,18 @@ struct FoldShapeExtentsOfShape
     auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
     if (!shapeOp)
       return mlir::failure();
-    rewriter.replaceOp(op, shapeOp.getExtents());
+
+    llvm::SmallVector<mlir::Value> results;
+    for (auto [extent, resultType] : llvm::zip(shapeOp.getExtents(), op.getResultTypes())) {
+      if (extent.getType() == resultType) {
+        results.push_back(extent);
+      } else if (fir::ConvertOp::canBeConverted(extent.getType(), resultType)) {
+        results.push_back(fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
+      } else {
+        return mlir::failure();
+      }
+    }
+    rewriter.replaceOp(op, results);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index c33a2103ab9e9..9d0199dbc5920 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -329,9 +329,6 @@ bool FIRToMemRef::materializeShapeExtents(
   if (!shapeVal)
     return false;
 
-  while (auto convertOp = shapeVal.getDefiningOp<fir::ConvertOp>())
-    shapeVal = convertOp.getOperand();
-
   if (auto shapeOp = shapeVal.getDefiningOp<fir::ShapeOp>()) {
     shapeVec.append(shapeOp.getExtents().begin(), shapeOp.getExtents().end());
     return true;
@@ -688,10 +685,14 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
     collectSliceInfoFrom(rebox, sliceInfo);
 
   if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
-    rewriter.setInsertionPoint(arrayCoorOp);
-    (void)materializeShapeExtents(arrayCoorOp.getShape(), rewriter, loc,
-                                  sliceInfo.shapeVec);
-    rewriter.setInsertionPointAfter(arrayCoorOp);
+    auto shapeVal = arrayCoorOp.getShape();
+    if (shapeVal && mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+      rewriter.setInsertionPoint(arrayCoorOp);
+      if(!materializeShapeExtents(shapeVal, rewriter, loc,
+                                    sliceInfo.shapeVec))
+                                      return failure();
+      rewriter.setInsertionPointAfter(arrayCoorOp);
+    }
   }
 
   Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
diff --git a/flang/test/Fir/shape-extents.mlir b/flang/test/Fir/shape-extents.mlir
index 3d7c17bf2fcb6..d35e2982ecaad 100644
--- a/flang/test/Fir/shape-extents.mlir
+++ b/flang/test/Fir/shape-extents.mlir
@@ -40,3 +40,18 @@ func.func @shape_extents_block_arg(%pred : i1, %n1 : index, %n2 : index) {
   fir.fake_use %e : index
   return
 }
+
+// Folding must insert a fir.convert when the fir.shape extent type (i64)
+// differs from the fir.shape_extents result type (index).
+// PLAIN-LABEL: func @fold_shape_extents_cast
+// PLAIN:         fir.shape_extents
+// CANON-LABEL:   func @fold_shape_extents_cast
+// CANON-NOT:     fir.shape_extents
+// CANON:         fir.convert %{{.*}} : (i64) -> index
+// CANON:         fir.fake_use %{{.*}} : index
+func.func @fold_shape_extents_cast(%e : i64) {
+  %sh = fir.shape %e : (i64) -> !fir.shape<1>
+  %ext = fir.shape_extents %sh : (!fir.shape<1>) -> (index)
+  fir.fake_use %ext : index
+  return
+}
\ No newline at end of file

>From af185eb62891d27a9ada36168dda15a0cec8d092 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Thu, 4 Jun 2026 12:35:17 -0700
Subject: [PATCH 5/6] fir.shape_extents: also accept fir.shape_shift operands

---
 .../include/flang/Optimizer/Dialect/FIROps.td |  6 +--
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 44 +++++++++++++++++--
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp |  6 +++
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 38 +++++++++++-----
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  | 13 ++++--
 5 files changed, 85 insertions(+), 22 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3d0ac4398e7df..74e7937550b8f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2088,10 +2088,10 @@ def fir_ShapeOp : fir_Op<"shape", [Pure]> {
 }
 
 def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
-  let summary = "unpack a `!fir.shape` into per-dimension extent SSA values";
+  let summary = "unpack a shape value into per-dimension extent SSA values";
 
   let description = [{
-    Takes a single abstract `!fir.shape<n>` value and yields `n` integer SSA
+    Takes a single abstract `!fir.shape<n>` or `!fir.shapeshift<n>` value and yields `n` integer SSA
     results, one per dimension extent in Fortran row-to-column order. This is
     intended for lowering when extent values are needed but the defining
     `fir.shape` is not visible (for example when a shape is forwarded through
@@ -2105,7 +2105,7 @@ def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
     ```
   }];
 
-  let arguments = (ins fir_ShapeType:$shape);
+  let arguments = (ins AnyShapeType:$shape);
 
   let results = (outs Variadic<AnyIntegerType>:$extents);
 
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f71455aada159..875a264ebbd03 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4613,9 +4613,18 @@ struct ShapeExtentsOpConversion
   matchAndRewrite(fir::ShapeExtentsOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto shapeTy = mlir::cast<fir::ShapeType>(op.getShape().getType());
-    if (shapeTy.getRank() != op.getNumResults())
+
+    mlir::Type ty = op.getShape().getType();
+    unsigned rank;
+    if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+      rank = shapeTy.getRank();
+    else if (auto ssTy = mlir::dyn_cast<fir::ShapeShiftType>(ty))
+      rank = ssTy.getRank();
+    else
+      return mlir::failure();
+    if (rank != op.getNumResults())
       return mlir::failure();
+
     mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
     mlir::Value llvmShape = adaptor.getShape();
     llvm::SmallVector<mlir::Value> results;
@@ -4631,8 +4640,35 @@ struct ShapeExtentsOpConversion
   }
 };
 
-struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
-  using MustBeDeadConversion::MustBeDeadConversion;
+// Lowers fir.shape_shift to fir.shape's LLVM struct: one i64 extent per dim.
+struct ShapeShiftOpConversion : public fir::FIROpConversion<fir::ShapeShiftOp> {
+  using FIROpConversion::FIROpConversion;
+
+  llvm::LogicalResult
+  matchAndRewrite(fir::ShapeShiftOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    if (op->use_empty()) {
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+    auto loc = op.getLoc();
+    auto ssTy = mlir::cast<fir::ShapeShiftType>(op.getType());
+    mlir::Type llvmTy = convertType(ssTy);
+    mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+    mlir::Value structVal = mlir::LLVM::UndefOp::create(rewriter, loc, llvmTy);
+    // Operands are (lower bound, extent) pairs, so extents sit at odd indices.
+    // Lower bounds are not part of the bundle read by fir.shape_extents.
+    for (auto [i, value] : llvm::enumerate(adaptor.getPairs())) {
+      if ((i & 1) == 0)
+        continue; // lower bound of dimension i / 2
+      mlir::Value extentI64 =
+          integerCast(loc, rewriter, i64Ty, value, /*fold=*/true);
+      structVal = mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal,
+                                                    extentI64, i / 2);
+    }
+    rewriter.replaceOp(op, structVal);
+    return mlir::success();
+  }
+};
 
 struct ShiftOpConversion : public MustBeDeadConversion<fir::ShiftOp> {
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 0cbf0aa259219..7241839661d02 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -129,6 +129,15 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
     return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
                                                   /*isPacked=*/false);
   });
+  // !fir.shapeshift<n> lowers to a literal struct with one i64 field per
+  // dimension extent; lower bounds are not carried in the LLVM value.
+  addConversion([&](fir::ShapeShiftType ssTy) {
+    auto *ctx = &getContext();
+    mlir::Type i64 = mlir::IntegerType::get(ctx, 64);
+    llvm::SmallVector<mlir::Type> fields(ssTy.getRank(), i64);
+    return mlir::LLVM::LLVMStructType::getLiteral(ctx, fields,
+                                                  /*isPacked=*/false);
+  });
   addConversion([&](fir::TypeDescType tdesc) {
     return convertTypeDescType(tdesc.getContext());
   });
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa48bfcc39cb4..0907c8e862db3 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4810,16 +4810,30 @@ struct FoldShapeExtentsOfShape
   mlir::LogicalResult
   matchAndRewrite(fir::ShapeExtentsOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
-    if (!shapeOp)
+    mlir::Value shape = op.getShape();
+    // Copy the extents into owning storage: ShapeShiftOp::getExtents()
+    // returns a temporary std::vector by value, so a non-owning ValueRange
+    // bound to it would dangle as soon as the assignment completes.
+    llvm::SmallVector<mlir::Value> extents;
+    if (auto shapeOp = shape.getDefiningOp<fir::ShapeOp>())
+      extents = llvm::to_vector(shapeOp.getExtents());
+    else if (auto ssOp = shape.getDefiningOp<fir::ShapeShiftOp>())
+      extents = llvm::to_vector(ssOp.getExtents());
+    else
       return mlir::failure();
-
+    // Check every extent before creating any op: a rewrite pattern must not
+    // modify the IR and then report failure.
+    for (auto [extent, resultType] : llvm::zip(extents, op.getResultTypes()))
+      if (extent.getType() != resultType &&
+          !fir::ConvertOp::canBeConverted(extent.getType(), resultType))
+        return mlir::failure();
     llvm::SmallVector<mlir::Value> results;
-    for (auto [extent, resultType] : llvm::zip(shapeOp.getExtents(), op.getResultTypes())) {
+    for (auto [extent, resultType] : llvm::zip(extents, op.getResultTypes())) {
       if (extent.getType() == resultType) {
         results.push_back(extent);
       } else if (fir::ConvertOp::canBeConverted(extent.getType(), resultType)) {
-        results.push_back(fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
+        results.push_back(
+            fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
       } else {
         return mlir::failure();
       }
@@ -4831,10 +4845,15 @@ struct FoldShapeExtentsOfShape
 } // namespace
 
 llvm::LogicalResult fir::ShapeExtentsOp::verify() {
-  auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getShape().getType());
-  if (!shapeTy)
-    return emitOpError("operand must be a !fir.shape type");
-  if (getNumResults() != shapeTy.getRank())
+  mlir::Type ty = getShape().getType();
+  unsigned rank;
+  if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+    rank = shapeTy.getRank();
+  else if (auto ssTy = mlir::dyn_cast<fir::ShapeShiftType>(ty))
+    rank = ssTy.getRank();
+  else
+    return emitOpError("operand must be !fir.shape or !fir.shapeshift");
+  if (getNumResults() != rank)
     return emitOpError("number of results must match shape rank");
   return mlir::success();
 }
@@ -4842,9 +4861,14 @@ llvm::LogicalResult fir::ShapeExtentsOp::verify() {
 void fir::ShapeExtentsOp::build(mlir::OpBuilder &builder,
                                 mlir::OperationState &result,
                                 mlir::Value shape) {
-  auto shapeTy = mlir::cast<fir::ShapeType>(shape.getType());
+  mlir::Type ty = shape.getType();
+  unsigned rank;
+  if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+    rank = shapeTy.getRank();
+  else
+    rank = mlir::cast<fir::ShapeShiftType>(ty).getRank();
   mlir::Type indexTy = builder.getIndexType();
-  llvm::SmallVector<mlir::Type> resultTypes(shapeTy.getRank(), indexTy);
+  llvm::SmallVector<mlir::Type> resultTypes(rank, indexTy);
   result.addTypes(resultTypes);
   result.addOperands(shape);
 }
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 9d0199dbc5920..040196562c503 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -178,8 +178,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
   void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
 
   /// Recover per-dimension extent SSA values from a shape operand. Inserts
-  /// `fir.shape_extents` when the defining `fir.shape` is not visible (e.g.
-  /// block argument from control-flow merge).
+  /// `fir.shape_extents` when the defining `fir.shape` or `fir.shape_shift` is
+  /// not visible (e.g. block argument from control-flow merge).
   bool materializeShapeExtents(Value shapeVal, PatternRewriter &rewriter,
                                Location loc,
                                SmallVectorImpl<Value> &shapeVec) const;
@@ -334,13 +334,18 @@ bool FIRToMemRef::materializeShapeExtents(
     return true;
   }
 
+  if (auto ssOp = shapeVal.getDefiningOp<fir::ShapeShiftOp>()) {
+    llvm::append_range(shapeVec, ssOp.getExtents());
+    return true;
+  }
+
   if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
     shapeVec.append(extentsOp.getExtents().begin(),
                     extentsOp.getExtents().end());
     return true;
   }
 
-  if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+  if (mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
     auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
     shapeVec.append(extentsOp.getExtents().begin(),
                     extentsOp.getExtents().end());
@@ -686,7 +691,8 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
 
   if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
     auto shapeVal = arrayCoorOp.getShape();
-    if (shapeVal && mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+    if (shapeVal &&
+        mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
       rewriter.setInsertionPoint(arrayCoorOp);
       if(!materializeShapeExtents(shapeVal, rewriter, loc,
                                     sliceInfo.shapeVec))

>From 0cd7187dc56c944cfe93c8de87c29d78f052eeb4 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Mon, 8 Jun 2026 09:47:06 -0700
Subject: [PATCH 6/6] fix fir.shape_shift extent packing; add CodeGen tests;
 drop the fir.shape_shift case from the invalid-conversion test

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  6 +-
 .../lib/Optimizer/Transforms/FIRToMemRef.cpp  |  6 --
 flang/test/Fir/convert-to-llvm-invalid.fir    | 12 ---
 flang/test/Fir/shape-to-llvm.mlir             | 89 +++++++++++++++++++
 4 files changed, 92 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Fir/shape-to-llvm.mlir

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 875a264ebbd03..21050a8fd1a5e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4659,7 +4659,9 @@ struct ShapeShiftOpConversion : public fir::FIROpConversion<fir::ShapeShiftOp> {
     // Pack extent operands only; lower bounds are not part of the LLVM shape
     // bundle consumed by fir.shape_extents.
     for (auto [i, extent] : llvm::enumerate(adaptor.getPairs())) {
-      if (i & 1)
+      // Operands alternate (lower bound, extent): skip the even-indexed lower
+      // bounds; the extent at operand index i goes to struct field i / 2.
+      if (!(i & 1))
         continue;
       mlir::Value extentI64 =
           integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 040196562c503..be035095fc0a6 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -339,12 +339,8 @@ bool FIRToMemRef::materializeShapeExtents(
     return true;
   }
 
-  if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
-    shapeVec.append(extentsOp.getExtents().begin(),
-                    extentsOp.getExtents().end());
-    return true;
-  }
-
+  // No visible fir.shape / fir.shape_shift producer (e.g. the shape is a
+  // block argument): unpack it with an explicit fir.shape_extents.
   if (mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
     auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
     shapeVec.append(extentsOp.getExtents().begin(),
diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir
index bd22dc81f05ca..b0c66e283bf5a 100644
--- a/flang/test/Fir/convert-to-llvm-invalid.fir
+++ b/flang/test/Fir/convert-to-llvm-invalid.fir
@@ -27,18 +27,6 @@ func.func @shift_not_dead(%arg0: !fir.box<!fir.array<?xf32>>, %i: index) {
 
 // -----
 
-// Test `fir.shape_shift` conversion failure because the op has uses.
-
-func.func @shape_shift_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
-  %c0 = arith.constant 1 : index
-  // expected-error at +1{{failed to legalize operation 'fir.shape_shift'}}
-  %0 = fir.shape_shift %c0, %c0, %c0, %c0 : (index, index, index, index) -> !fir.shapeshift<2>
-  %1 = fir.array_coor %arg0(%0) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>, index, index) -> !fir.ref<f32>
-  return
-}
-
-// -----
-
 // Test `fir.select_type` conversion to llvm.
 // Should have been converted.
 
diff --git a/flang/test/Fir/shape-to-llvm.mlir b/flang/test/Fir/shape-to-llvm.mlir
new file mode 100644
index 0000000000000..91df929a39a4e
--- /dev/null
+++ b/flang/test/Fir/shape-to-llvm.mlir
@@ -0,0 +1,89 @@
+// RUN: fir-opt %s -split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %s
+
+// ShapeOpConversion: live fir.shape lowers to an n-field LLVM struct.
+// CHECK-LABEL: llvm.func @live_shape(
+// CHECK-SAME: %[[N:.+]]: i64
+// CHECK: %[[UNDEF:.+]] = llvm.mlir.undef
+// CHECK: %[[STRUCT:.+]] = llvm.insertvalue %[[N]], %[[UNDEF]][0]
+// CHECK: llvm.extractvalue %[[STRUCT]][0]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape(%n : index) {
+  %sh = fir.shape %n : (index) -> !fir.shape<1>
+  %e = fir.shape_extents %sh : (!fir.shape<1>) -> index
+  %c0 = arith.constant 0 : index
+  %sink = arith.addi %e, %c0 : index
+  return
+}
+
+// -----
+
+// ShapeShiftOpConversion: packs extent operands only (n-field struct, not 2n).
+// ShapeExtentsOpConversion: extractvalue field 0 for rank-1 shapeshift.
+// CHECK-LABEL: llvm.func @live_shape_shift(
+// CHECK-SAME: %[[LB:.+]]: i64, %[[EXT:.+]]: i64
+// CHECK: %[[UNDEF:.+]] = llvm.mlir.undef
+// CHECK: %[[STRUCT:.+]] = llvm.insertvalue %[[EXT]], %[[UNDEF]][0]
+// CHECK-NOT: llvm.insertvalue %[[LB]]
+// CHECK: llvm.extractvalue %[[STRUCT]][0]
+// CHECK-NOT: fir.shape_shift
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape_shift(%lb : index, %ext : index) {
+  %ss = fir.shape_shift %lb, %ext : (index, index) -> !fir.shapeshift<1>
+  %e = fir.shape_extents %ss : (!fir.shapeshift<1>) -> index
+  %c0 = arith.constant 0 : index
+  %sink = arith.addi %e, %c0 : index
+  return
+}
+
+// -----
+
+// ShapeExtentsOpConversion on a forwarded !fir.shape block argument.
+// CHECK-LABEL: llvm.func @live_shape_extents_forwarded(
+// CHECK-SAME: %[[PRED:.+]]: i1, %[[N1:.+]]: i64, %[[N2:.+]]: i64
+// CHECK: llvm.cond_br %[[PRED]]
+// CHECK: %[[UNDEF1:.+]] = llvm.mlir.undef
+// CHECK: %[[S1:.+]] = llvm.insertvalue %[[N1]], %[[UNDEF1]][0]
+// CHECK: llvm.br {{.*}}(%[[S1]]
+// CHECK: %[[UNDEF2:.+]] = llvm.mlir.undef
+// CHECK: %[[S2:.+]] = llvm.insertvalue %[[N2]], %[[UNDEF2]][0]
+// CHECK: llvm.br {{.*}}(%[[S2]]
+// CHECK: llvm.extractvalue %{{.*}}[0]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape_extents_forwarded(%pred : i1, %n1 : index, %n2 : index) {
+  cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+  %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+  %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+  cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+  %e = fir.shape_extents %phi : (!fir.shape<1>) -> index
+  %c0 = arith.constant 0 : index
+  %sink = arith.addi %e, %c0 : index
+  return
+}
+
+// -----
+
+// 2-D shape: struct has two extent fields; shape_extents extracts [0] and [1].
+// CHECK-LABEL: llvm.func @live_shape_extents_2d(
+// CHECK: llvm.insertvalue {{.+}}[0]
+// CHECK: llvm.insertvalue {{.+}}[1]
+// CHECK: llvm.extractvalue {{.+}}[0]
+// CHECK: llvm.extractvalue {{.+}}[1]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+func.func @live_shape_extents_2d(%n1 : index, %n2 : index) {
+  %sh = fir.shape %n1, %n2 : (index, index) -> !fir.shape<2>
+  %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+  %c0 = arith.constant 0 : index
+  %s0 = arith.addi %e0, %c0 : index
+  %s1 = arith.addi %e1, %c0 : index
+  return
+}



More information about the flang-commits mailing list