[flang-commits] [flang] [FIR] add a fir.shape_extents operation (PR #199361)
via flang-commits
flang-commits at lists.llvm.org
Tue Jun 9 09:16:18 PDT 2026
https://github.com/yebinchon updated https://github.com/llvm/llvm-project/pull/199361
>From a89731080414d08fb317344fdf6becc950d094bf Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 09:32:01 -0700
Subject: [PATCH 1/6] [FIR] add a shape_extents operation that takes a
fir.shape and unpacks it into n SSA values (one per dim)
---
.../include/flang/Optimizer/Dialect/FIROps.td | 35 +++++++++
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 69 +++++++++++++---
flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 6 ++
flang/lib/Optimizer/Dialect/FIROps.cpp | 45 +++++++++++
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 78 +++++++++++++++----
flang/test/Fir/convert-to-llvm-invalid.fir | 12 ---
flang/test/Fir/shape-extents.mlir | 42 ++++++++++
7 files changed, 251 insertions(+), 36 deletions(-)
create mode 100644 flang/test/Fir/shape-extents.mlir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index ec42855761bcb..3d0ac4398e7df 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2087,6 +2087,41 @@ def fir_ShapeOp : fir_Op<"shape", [Pure]> {
let builders = [OpBuilder<(ins "mlir::ValueRange":$extents)>];
}
+def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
+ let summary = "unpack a `!fir.shape` into per-dimension extent SSA values";
+
+ let description = [{
+ Takes a single abstract `!fir.shape<n>` value and yields `n` integer SSA
+ results, one per dimension extent in Fortran row-to-column order. This is
+ intended for lowering when extent values are needed but the defining
+ `fir.shape` is not visible (for example when a shape is forwarded through
+ control flow as a block argument).
+
+ When the operand is the result of `fir.shape`, a canonicalization may fold
+ this operation to the original extent operands.
+
+ ```
+ %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+ ```
+ }];
+
+ let arguments = (ins fir_ShapeType:$shape);
+
+ let results = (outs Variadic<AnyIntegerType>:$extents);
+
+ let assemblyFormat = [{
+ $shape attr-dict `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
+
+ let skipDefaultBuilders = 1;
+
+ let builders = [OpBuilder<(ins "mlir::Value":$shape)>];
+}
+
def fir_ShapeShiftOp : fir_Op<"shape_shift", [Pure]> {
let summary = [{
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 6b1acba393170..523ca462dde78 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4577,8 +4577,57 @@ struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> {
}
};
-struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
- using MustBeDeadConversion::MustBeDeadConversion;
+// Lower fir.shape to an LLVM struct holding each extent as i64 (one per dim).
+struct ShapeOpConversion : public fir::FIROpConversion<fir::ShapeOp> {
+ using FIROpConversion::FIROpConversion;
+ // Unused shapes are erased; otherwise extents are cast to i64 and inserted.
+ llvm::LogicalResult
+ matchAndRewrite(fir::ShapeOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ if (op->use_empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+ auto loc = op.getLoc();
+ auto shapeTy = mlir::cast<fir::ShapeType>(op.getType());
+ mlir::Type llvmShapeTy = convertType(shapeTy);
+ mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+ mlir::Value structVal =
+ mlir::LLVM::UndefOp::create(rewriter, loc, llvmShapeTy);
+ for (auto [i, extent] : llvm::enumerate(adaptor.getExtents())) {
+ mlir::Value extentI64 =
+ integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
+ structVal =
+ mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal, extentI64, i);
+ }
+ rewriter.replaceOp(op, structVal);
+ return mlir::success();
+ }
+};
+
+struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsOp> {
+ using FIROpConversion::FIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(fir::ShapeExtentsOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto shapeTy = mlir::cast<fir::ShapeType>(op.getShape().getType());
+ if (shapeTy.getRank() != op.getNumResults())
+ return mlir::failure();
+ mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+ mlir::Value llvmShape = adaptor.getShape();
+ llvm::SmallVector<mlir::Value> results;
+ for (unsigned i = 0; i < op.getNumResults(); ++i) {
+ mlir::Value extentI64 =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, i64Ty, llvmShape, i);
+ mlir::Type resultTy = convertType(op.getExtents()[i].getType());
+ results.push_back(
+ integerCast(loc, rewriter, resultTy, extentI64, /*fold=*/true));
+ }
+ rewriter.replaceOp(op, results);
+ return mlir::success();
+ }
};
struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
@@ -4880,14 +4929,14 @@ void fir::populateFIRToLLVMConversionPatterns(
LogicalOrOpConversion, MulcOpConversion, NegcOpConversion,
NeqvOpConversion, NoReassocOpConversion, PrefetchOpConversion,
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
- SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
- ShiftOpConversion, SliceOpConversion, StoreOpConversion,
- StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
- TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
- UndefOpConversion, UnreachableOpConversion, UseStmtOpConversion,
- ModuleDebugImportsOpConversion, XArrayCoorOpConversion,
- XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
- options);
+ SelectTypeOpConversion, ShapeOpConversion, ShapeExtentsOpConversion,
+ ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
+ StoreOpConversion, StringLitOpConversion, SubcOpConversion,
+ TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
+ UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
+ UseStmtOpConversion, ModuleDebugImportsOpConversion,
+ XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
+ ZeroOpConversion>(converter, options);
// Patterns that are populated without a type converter do not trigger
// target materializations for the operands of the root op.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 2fca4111e0980..0cbf0aa259219 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -123,6 +123,12 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
addConversion([&](fir::SequenceType sequence) {
return convertSequenceType(sequence);
});
+ addConversion([&](fir::ShapeType shape) {
+ mlir::Type i64Ty = mlir::IntegerType::get(&getContext(), 64);
+ llvm::SmallVector<mlir::Type> members(shape.getRank(), i64Ty);
+ return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
+ /*isPacked=*/false);
+ });
addConversion([&](fir::TypeDescType tdesc) {
return convertTypeDescType(tdesc.getContext());
});
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 937f5c3f07e7d..c53d1b20456c7 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4798,6 +4798,51 @@ void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
build(builder, result, type, extents);
}
+//===----------------------------------------------------------------------===//
+// ShapeExtentsOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldShapeExtentsOfShape
+ : public mlir::OpRewritePattern<fir::ShapeExtentsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::ShapeExtentsOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
+ if (!shapeOp)
+ return mlir::failure();
+ rewriter.replaceOp(op, shapeOp.getExtents());
+ return mlir::success();
+ }
+};
+} // namespace
+
+llvm::LogicalResult fir::ShapeExtentsOp::verify() {
+ auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getShape().getType());
+ if (!shapeTy)
+ return emitOpError("operand must be a !fir.shape type");
+ if (getNumResults() != shapeTy.getRank())
+ return emitOpError("number of results must match shape rank");
+ return mlir::success();
+}
+
+void fir::ShapeExtentsOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &result,
+ mlir::Value shape) {
+ auto shapeTy = mlir::cast<fir::ShapeType>(shape.getType());
+ mlir::Type indexTy = builder.getIndexType();
+ llvm::SmallVector<mlir::Type> resultTypes(shapeTy.getRank(), indexTy);
+ result.addTypes(resultTypes);
+ result.addOperands(shape);
+}
+
+void fir::ShapeExtentsOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+ patterns.add<FoldShapeExtentsOfShape>(context);
+}
+
//===----------------------------------------------------------------------===//
// ShapeShiftOp
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 2c27e34a1d5c2..3f2a049d8a89e 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -177,6 +177,13 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
+ /// Recover per-dimension extent SSA values from a shape operand. Inserts
+ /// `fir.shape_extents` when the defining `fir.shape` is not visible (e.g.
+ /// block argument from control-flow merge).
+ bool materializeShapeExtents(Value shapeVal, PatternRewriter &rewriter,
+ Location loc,
+ SmallVectorImpl<Value> &shapeVec) const;
+
static fir::SliceOp getSliceOp(Value sliceVal) {
return sliceVal ? sliceVal.getDefiningOp<fir::SliceOp>() : fir::SliceOp{};
}
@@ -316,6 +323,36 @@ void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
vec.append(shape.getExtents().begin(), shape.getExtents().end());
}
+bool FIRToMemRef::materializeShapeExtents(
+ Value shapeVal, PatternRewriter &rewriter, Location loc,
+ SmallVectorImpl<Value> &shapeVec) const {
+ if (!shapeVal)
+ return false;
+
+ while (auto convertOp = shapeVal.getDefiningOp<fir::ConvertOp>())
+ shapeVal = convertOp.getOperand();
+
+ if (auto shapeOp = shapeVal.getDefiningOp<fir::ShapeOp>()) {
+ shapeVec.append(shapeOp.getExtents().begin(), shapeOp.getExtents().end());
+ return true;
+ }
+
+ if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
+ shapeVec.append(extentsOp.getExtents().begin(),
+ extentsOp.getExtents().end());
+ return true;
+ }
+
+ if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+ auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
+ shapeVec.append(extentsOp.getExtents().begin(),
+ extentsOp.getExtents().end());
+ return true;
+ }
+
+ return false;
+}
+
template <typename OpTy>
void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
@@ -324,14 +361,14 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
Value shapeVal = op.getShape();
if (shapeVal) {
- Operation *shapeValOp = shapeVal.getDefiningOp();
-
- if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
- populateShape(info.shapeVec, shapeOp);
- } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
- populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
- } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
- populateShift(info.shiftVec, shiftOp);
+ if (Operation *shapeValOp = shapeVal.getDefiningOp()) {
+ if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
+ populateShape(info.shapeVec, shapeOp);
+ } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
+ populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
+ } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
+ populateShift(info.shiftVec, shiftOp);
+ }
}
}
@@ -642,6 +679,20 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
rewriter.setInsertionPointAfter(arrayCoorOp);
}
+ SliceInfo sliceInfo;
+ collectSliceInfoFrom(arrayCoorOp, sliceInfo);
+ if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
+ collectSliceInfoFrom(embox, sliceInfo);
+ else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
+ collectSliceInfoFrom(rebox, sliceInfo);
+
+ if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
+ rewriter.setInsertionPoint(arrayCoorOp);
+ (void)materializeShapeExtents(arrayCoorOp.getShape(), rewriter, loc,
+ sliceInfo.shapeVec);
+ rewriter.setInsertionPointAfter(arrayCoorOp);
+ }
+
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
FailureOr<SmallVector<Value>> failureOrIndices =
getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one);
@@ -662,12 +713,6 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
// For complex projections, reinterpret memref<d0×...×complex<T>> as
// memref<d0×...×2×T> and append the component index (0=re, 1=im) so that
// each load/store touches exactly sizeof(T) bytes.
- SliceInfo sliceInfo;
- collectSliceInfoFrom(arrayCoorOp, sliceInfo);
- if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
- collectSliceInfoFrom(embox, sliceInfo);
- else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
- collectSliceInfoFrom(rebox, sliceInfo);
auto srcTy = cast<MemRefType>((*converted).getType());
if (sliceInfo.hasProjectedSlice) {
if (auto complexTy = dyn_cast<mlir::ComplexType>(srcTy.getElementType())) {
@@ -729,6 +774,11 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
(isDescriptor && sliceInfo.shiftVec.empty() && arrayCoorOp.getShape() &&
arrayCoorOp.getSlice());
if (descriptorOwnsLayout) {
+ // Plain `!fir.ref` without recoverable shape extents cannot use fir.box_*.
+ if (shapeVec.empty() && !sliceInfo.hasProjectedSlice && !isDescriptor &&
+ !isRebox)
+ return failure();
+
// Projected slices carry their physical layout in the descriptor. Rebuild
// the MemRef view from box metadata instead of from slice triplets.
auto boxElementSize =
diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir
index 9f003e4eb7d59..bd22dc81f05ca 100644
--- a/flang/test/Fir/convert-to-llvm-invalid.fir
+++ b/flang/test/Fir/convert-to-llvm-invalid.fir
@@ -3,18 +3,6 @@
// RUN: fir-opt --split-input-file --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" --verify-diagnostics %s
-// Test `fir.shape` conversion failure because the op has uses.
-
-func.func @shape_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
- %c0 = arith.constant 1 : index
- // expected-error at +1{{failed to legalize operation 'fir.shape'}}
- %0 = fir.shape %c0, %c0 : (index, index) -> !fir.shape<2>
- %1 = fir.array_coor %arg0(%0) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
- return
-}
-
-// -----
-
// Test `fir.slice` conversion failure because the op has uses.
func.func @slice_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
diff --git a/flang/test/Fir/shape-extents.mlir b/flang/test/Fir/shape-extents.mlir
new file mode 100644
index 0000000000000..3d7c17bf2fcb6
--- /dev/null
+++ b/flang/test/Fir/shape-extents.mlir
@@ -0,0 +1,42 @@
+// RUN: fir-opt %s | FileCheck %s --check-prefix=PLAIN
+// RUN: fir-opt %s -canonicalize | FileCheck %s --check-prefix=CANON
+
+// PLAIN-LABEL: func @fold_shape_extents
+// PLAIN: fir.shape_extents
+// CANON-LABEL: func @fold_shape_extents
+// CANON: fir.fake_use %[[N:arg[0-9]+]] : index
+func.func @fold_shape_extents(%n : index) {
+ %sh = fir.shape %n : (index) -> !fir.shape<1>
+ %e = fir.shape_extents %sh : (!fir.shape<1>) -> (index)
+ fir.fake_use %e : index
+ return
+}
+
+// PLAIN-LABEL: func @shape_extents_2d
+// PLAIN: fir.shape_extents
+// CANON-LABEL: func @shape_extents_2d
+// CANON: fir.fake_use %[[N1:arg[0-9]+]], %[[N2:arg[0-9]+]] : index, index
+func.func @shape_extents_2d(%n1 : index, %n2 : index) {
+ %sh = fir.shape %n1, %n2 : (index, index) -> !fir.shape<2>
+ %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+ fir.fake_use %e0, %e1 : index, index
+ return
+}
+
+// PLAIN-LABEL: func @shape_extents_block_arg
+// PLAIN: fir.shape_extents
+// CANON-LABEL: func @shape_extents_block_arg
+// CANON: fir.shape_extents
+func.func @shape_extents_block_arg(%pred : i1, %n1 : index, %n2 : index) {
+ cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+ %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+ %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+ %e = fir.shape_extents %phi : (!fir.shape<1>) -> (index)
+ fir.fake_use %e : index
+ return
+}
>From e1419420a5e115760315d772a5974277e11074f2 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 09:46:14 -0700
Subject: [PATCH 2/6] add forwarded-shape test
---
.../FIRToMemRef/forwarded-shape.mlir | 24 +++++++++++++++++++
1 file changed, 24 insertions(+)
create mode 100644 flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir
diff --git a/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir b/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir
new file mode 100644
index 0000000000000..b3e8cae4e00a3
--- /dev/null
+++ b/flang/test/Transforms/FIRToMemRef/forwarded-shape.mlir
@@ -0,0 +1,24 @@
+// RUN: fir-opt %s --fir-to-memref --allow-unregistered-dialect | FileCheck %s
+
+func.func @forwarded_shape_store(%pred : i1, %n1 : index, %n2 : index,
+ %arg0: !fir.ref<!fir.array<?xi32>>) {
+ cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+ %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+ %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : i32
+ %elt = fir.array_coor %arg0(%phi) %c1
+ : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
+ fir.store %c42 to %elt : !fir.ref<i32>
+ return
+}
+
+// CHECK-LABEL: func.func @forwarded_shape_store
+// CHECK: fir.shape_extents
+// CHECK-NOT: fir.array_coor
+// CHECK: memref.store
>From f93e2a15fc5abe3269497aa5f215d7f198c09496 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Sat, 23 May 2026 10:16:39 -0700
Subject: [PATCH 3/6] formatting
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 11 ++++++-----
flang/lib/Optimizer/Transforms/FIRToMemRef.cpp | 7 ++++---
2 files changed, 10 insertions(+), 8 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 523ca462dde78..f71455aada159 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4597,15 +4597,16 @@ struct ShapeOpConversion : public fir::FIROpConversion<fir::ShapeOp> {
for (auto [i, extent] : llvm::enumerate(adaptor.getExtents())) {
mlir::Value extentI64 =
integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
- structVal =
- mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal, extentI64, i);
+ structVal = mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal,
+ extentI64, i);
}
rewriter.replaceOp(op, structVal);
return mlir::success();
}
};
-struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsOp> {
+struct ShapeExtentsOpConversion
+ : public fir::FIROpConversion<fir::ShapeExtentsOp> {
using FIROpConversion::FIROpConversion;
llvm::LogicalResult
@@ -4619,8 +4620,8 @@ struct ShapeExtentsOpConversion : public fir::FIROpConversion<fir::ShapeExtentsO
mlir::Value llvmShape = adaptor.getShape();
llvm::SmallVector<mlir::Value> results;
for (unsigned i = 0; i < op.getNumResults(); ++i) {
- mlir::Value extentI64 =
- mlir::LLVM::ExtractValueOp::create(rewriter, loc, i64Ty, llvmShape, i);
+ mlir::Value extentI64 = mlir::LLVM::ExtractValueOp::create(
+ rewriter, loc, i64Ty, llvmShape, i);
mlir::Type resultTy = convertType(op.getExtents()[i].getType());
results.push_back(
integerCast(loc, rewriter, resultTy, extentI64, /*fold=*/true));
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 3f2a049d8a89e..c33a2103ab9e9 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -339,14 +339,14 @@ bool FIRToMemRef::materializeShapeExtents(
if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
shapeVec.append(extentsOp.getExtents().begin(),
- extentsOp.getExtents().end());
+ extentsOp.getExtents().end());
return true;
}
if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
shapeVec.append(extentsOp.getExtents().begin(),
- extentsOp.getExtents().end());
+ extentsOp.getExtents().end());
return true;
}
@@ -364,7 +364,8 @@ void FIRToMemRef::collectSliceInfoFrom(OpTy op, SliceInfo &info) const {
if (Operation *shapeValOp = shapeVal.getDefiningOp()) {
if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
populateShape(info.shapeVec, shapeOp);
- } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
+ } else if (auto shapeShiftOp =
+ dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
populateShapeAndShift(info.shapeVec, info.shiftVec, shapeShiftOp);
} else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
populateShift(info.shiftVec, shiftOp);
>From ca5285dc5ed70c4803087f52d716683fe67c6c58 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Tue, 2 Jun 2026 10:08:09 -0700
Subject: [PATCH 4/6] insert fir.convert when fir.shape extent types differ from
 fir.shape_extents result types; fail when the shape cannot be recovered
---
flang/lib/Optimizer/Dialect/FIROps.cpp | 13 ++++++++++++-
flang/lib/Optimizer/Transforms/FIRToMemRef.cpp | 15 ++++++++-------
flang/test/Fir/shape-extents.mlir | 15 +++++++++++++++
3 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index c53d1b20456c7..fa48bfcc39cb4 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4813,7 +4813,18 @@ struct FoldShapeExtentsOfShape
auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
if (!shapeOp)
return mlir::failure();
- rewriter.replaceOp(op, shapeOp.getExtents());
+
+ llvm::SmallVector<mlir::Value> results;
+ for (auto [extent, resultType] : llvm::zip(shapeOp.getExtents(), op.getResultTypes())) {
+ if (extent.getType() == resultType) {
+ results.push_back(extent);
+ } else if (fir::ConvertOp::canBeConverted(extent.getType(), resultType)) {
+ results.push_back(fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
+ } else {
+ return mlir::failure();
+ }
+ }
+ rewriter.replaceOp(op, results);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index c33a2103ab9e9..9d0199dbc5920 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -329,9 +329,6 @@ bool FIRToMemRef::materializeShapeExtents(
if (!shapeVal)
return false;
- while (auto convertOp = shapeVal.getDefiningOp<fir::ConvertOp>())
- shapeVal = convertOp.getOperand();
-
if (auto shapeOp = shapeVal.getDefiningOp<fir::ShapeOp>()) {
shapeVec.append(shapeOp.getExtents().begin(), shapeOp.getExtents().end());
return true;
@@ -688,10 +685,14 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
collectSliceInfoFrom(rebox, sliceInfo);
if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
- rewriter.setInsertionPoint(arrayCoorOp);
- (void)materializeShapeExtents(arrayCoorOp.getShape(), rewriter, loc,
- sliceInfo.shapeVec);
- rewriter.setInsertionPointAfter(arrayCoorOp);
+ auto shapeVal = arrayCoorOp.getShape();
+ if (shapeVal && mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+ rewriter.setInsertionPoint(arrayCoorOp);
+ if(!materializeShapeExtents(shapeVal, rewriter, loc,
+ sliceInfo.shapeVec))
+ return failure();
+ rewriter.setInsertionPointAfter(arrayCoorOp);
+ }
}
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
diff --git a/flang/test/Fir/shape-extents.mlir b/flang/test/Fir/shape-extents.mlir
index 3d7c17bf2fcb6..d35e2982ecaad 100644
--- a/flang/test/Fir/shape-extents.mlir
+++ b/flang/test/Fir/shape-extents.mlir
@@ -40,3 +40,18 @@ func.func @shape_extents_block_arg(%pred : i1, %n1 : index, %n2 : index) {
fir.fake_use %e : index
return
}
+
+// Check for proper insertion of casting when types of
+// fir.shape ops and fir.shape_extents results do not match
+// PLAIN-LABEL: func @fold_shape_extents_cast
+// PLAIN: fir.shape_extents
+// CANON-LABEL: func @fold_shape_extents_cast
+// CANON-NOT: fir.shape_extents
+// CANON: fir.convert %{{.*}} : (i64) -> index
+// CANON: fir.fake_use %{{.*}} : index
+func.func @fold_shape_extents_cast(%e : i64) {
+ %sh = fir.shape %e : (i64) -> !fir.shape<1>
+ %ext = fir.shape_extents %sh : (!fir.shape<1>) -> (index)
+ fir.fake_use %ext : index
+ return
+}
\ No newline at end of file
>From af185eb62891d27a9ada36168dda15a0cec8d092 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Thu, 4 Jun 2026 12:35:17 -0700
Subject: [PATCH 5/6] allow fir.shape_extents to take a !fir.shapeshift operand
---
.../include/flang/Optimizer/Dialect/FIROps.td | 6 +--
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 44 +++++++++++++++++--
flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 6 +++
flang/lib/Optimizer/Dialect/FIROps.cpp | 38 +++++++++++-----
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 13 ++++--
5 files changed, 85 insertions(+), 22 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 3d0ac4398e7df..74e7937550b8f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2088,10 +2088,10 @@ def fir_ShapeOp : fir_Op<"shape", [Pure]> {
}
def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
- let summary = "unpack a `!fir.shape` into per-dimension extent SSA values";
+ let summary = "unpack a shape value into per-dimension extent SSA values";
let description = [{
- Takes a single abstract `!fir.shape<n>` value and yields `n` integer SSA
+ Takes a single abstract `!fir.shape<n>` or `!fir.shapeshift<n>` value and yields `n` integer SSA
results, one per dimension extent in Fortran row-to-column order. This is
intended for lowering when extent values are needed but the defining
`fir.shape` is not visible (for example when a shape is forwarded through
@@ -2105,7 +2105,7 @@ def fir_ShapeExtentsOp : fir_Op<"shape_extents", [Pure]> {
```
}];
- let arguments = (ins fir_ShapeType:$shape);
+ let arguments = (ins AnyShapeType:$shape);
let results = (outs Variadic<AnyIntegerType>:$extents);
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f71455aada159..875a264ebbd03 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4613,9 +4613,18 @@ struct ShapeExtentsOpConversion
matchAndRewrite(fir::ShapeExtentsOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto shapeTy = mlir::cast<fir::ShapeType>(op.getShape().getType());
- if (shapeTy.getRank() != op.getNumResults())
+
+ mlir::Type ty = op.getShape().getType();
+ unsigned rank;
+ if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+ rank = shapeTy.getRank();
+ else if (auto ssTy = mlir::dyn_cast<fir::ShapeShiftType>(ty))
+ rank = ssTy.getRank();
+ else
+ return mlir::failure();
+ if (rank != op.getNumResults())
return mlir::failure();
+
mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
mlir::Value llvmShape = adaptor.getShape();
llvm::SmallVector<mlir::Value> results;
@@ -4631,8 +4640,35 @@ struct ShapeExtentsOpConversion
}
};
-struct ShapeShiftOpConversion : public MustBeDeadConversion<fir::ShapeShiftOp> {
- using MustBeDeadConversion::MustBeDeadConversion;
+struct ShapeShiftOpConversion : public fir::FIROpConversion<fir::ShapeShiftOp> {
+ using FIROpConversion::FIROpConversion;
+ // Lowers to the same {i64 x rank} extent struct as fir.shape (lbs dropped).
+ llvm::LogicalResult
+ matchAndRewrite(fir::ShapeShiftOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ if (op->use_empty()) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+ auto loc = op.getLoc();
+ auto ssTy = mlir::cast<fir::ShapeShiftType>(op.getType());
+ mlir::Type llvmTy = convertType(ssTy);
+ mlir::Type i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+ mlir::Value structVal =
+ mlir::LLVM::UndefOp::create(rewriter, loc, llvmTy);
+ // Operands are (lb0, ext0, lb1, ext1, ...): pack only the odd-indexed
+ // extents, matching fir::ShapeShiftOp::getExtents() used by the folder.
+ for (auto [i, extent] : llvm::enumerate(adaptor.getPairs())) {
+ if (!(i & 1))
+ continue;
+ mlir::Value extentI64 =
+ integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
+ structVal = mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal,
+ extentI64, i / 2);
+ }
+ rewriter.replaceOp(op, structVal);
+ return mlir::success();
+ }
+};
struct ShiftOpConversion : public MustBeDeadConversion<fir::ShiftOp> {
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 0cbf0aa259219..7241839661d02 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -129,6 +129,12 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
/*isPacked=*/false);
});
+ addConversion([&](fir::ShapeShiftType shapeShift) {
+ mlir::Type i64Ty = mlir::IntegerType::get(&getContext(), 64);
+ llvm::SmallVector<mlir::Type> members(shapeShift.getRank(), i64Ty);
+ return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
+ /*isPacked=*/false);
+ });
addConversion([&](fir::TypeDescType tdesc) {
return convertTypeDescType(tdesc.getContext());
});
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa48bfcc39cb4..0907c8e862db3 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4810,16 +4810,22 @@ struct FoldShapeExtentsOfShape
mlir::LogicalResult
matchAndRewrite(fir::ShapeExtentsOp op,
mlir::PatternRewriter &rewriter) const override {
- auto shapeOp = op.getShape().getDefiningOp<fir::ShapeOp>();
- if (!shapeOp)
+ mlir::Value shape = op.getShape();
+ llvm::SmallVector<mlir::Value> extents; // owning: getExtents() may be a temp
+ if (auto shapeOp = shape.getDefiningOp<fir::ShapeOp>())
+ llvm::append_range(extents, shapeOp.getExtents());
+ else if (auto ssOp = shape.getDefiningOp<fir::ShapeShiftOp>())
+ llvm::append_range(extents, ssOp.getExtents());
+ else
return mlir::failure();
-
llvm::SmallVector<mlir::Value> results;
- for (auto [extent, resultType] : llvm::zip(shapeOp.getExtents(), op.getResultTypes())) {
+ for (auto [extent, resultType] :
+ llvm::zip(extents, op.getResultTypes())) {
if (extent.getType() == resultType) {
results.push_back(extent);
} else if (fir::ConvertOp::canBeConverted(extent.getType(), resultType)) {
- results.push_back(fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
+ results.push_back(
+ fir::ConvertOp::create(rewriter, op.getLoc(), resultType, extent));
} else {
return mlir::failure();
}
@@ -4831,10 +4837,15 @@ struct FoldShapeExtentsOfShape
} // namespace
llvm::LogicalResult fir::ShapeExtentsOp::verify() {
- auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getShape().getType());
- if (!shapeTy)
- return emitOpError("operand must be a !fir.shape type");
- if (getNumResults() != shapeTy.getRank())
+ mlir::Type ty = getShape().getType();
+ unsigned rank;
+ if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+ rank = shapeTy.getRank();
+ else if (auto ssTy = mlir::dyn_cast<fir::ShapeShiftType>(ty))
+ rank = ssTy.getRank();
+ else
+ return emitOpError("operand must be !fir.shape or !fir.shapeshift");
+ if (getNumResults() != rank)
return emitOpError("number of results must match shape rank");
return mlir::success();
}
@@ -4842,9 +4853,14 @@ llvm::LogicalResult fir::ShapeExtentsOp::verify() {
void fir::ShapeExtentsOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result,
mlir::Value shape) {
- auto shapeTy = mlir::cast<fir::ShapeType>(shape.getType());
+ mlir::Type ty = shape.getType();
+ unsigned rank;
+ if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty))
+ rank = shapeTy.getRank();
+ else
+ rank = mlir::cast<fir::ShapeShiftType>(ty).getRank();
mlir::Type indexTy = builder.getIndexType();
- llvm::SmallVector<mlir::Type> resultTypes(shapeTy.getRank(), indexTy);
+ llvm::SmallVector<mlir::Type> resultTypes(rank, indexTy);
result.addTypes(resultTypes);
result.addOperands(shape);
}
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 9d0199dbc5920..040196562c503 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -178,8 +178,8 @@ class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
/// Recover per-dimension extent SSA values from a shape operand. Inserts
- /// `fir.shape_extents` when the defining `fir.shape` is not visible (e.g.
- /// block argument from control-flow merge).
+ /// `fir.shape_extents` when the defining `fir.shape` or `fir.shapeshift` is
+ /// not visible (e.g. block argument from control-flow merge).
bool materializeShapeExtents(Value shapeVal, PatternRewriter &rewriter,
Location loc,
SmallVectorImpl<Value> &shapeVec) const;
@@ -334,13 +334,18 @@ bool FIRToMemRef::materializeShapeExtents(
return true;
}
+ if (auto ssOp = shapeVal.getDefiningOp<fir::ShapeShiftOp>()) {
+ shapeVec.append(ssOp.getExtents().begin(), ssOp.getExtents().end());
+ return true;
+ }
+
if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
shapeVec.append(extentsOp.getExtents().begin(),
extentsOp.getExtents().end());
return true;
}
- if (mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+ if (mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
shapeVec.append(extentsOp.getExtents().begin(),
extentsOp.getExtents().end());
@@ -686,7 +691,7 @@ FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
if (!sliceInfo.hasProjectedSlice && sliceInfo.shapeVec.empty()) {
auto shapeVal = arrayCoorOp.getShape();
- if (shapeVal && mlir::isa<fir::ShapeType>(shapeVal.getType())) {
+ if (shapeVal && mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
rewriter.setInsertionPoint(arrayCoorOp);
if(!materializeShapeExtents(shapeVal, rewriter, loc,
sliceInfo.shapeVec))
>From 0cd7187dc56c944cfe93c8de87c29d78f052eeb4 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Mon, 8 Jun 2026 09:47:06 -0700
Subject: [PATCH 6/6] add CodeGen tests. remove the fir.shape_shift failure case
from the invalid conversion test
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 6 +-
.../lib/Optimizer/Transforms/FIRToMemRef.cpp | 6 --
flang/test/Fir/convert-to-llvm-invalid.fir | 12 ---
flang/test/Fir/shape-to-llvm.mlir | 89 +++++++++++++++++++
4 files changed, 92 insertions(+), 21 deletions(-)
create mode 100644 flang/test/Fir/shape-to-llvm.mlir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 875a264ebbd03..21050a8fd1a5e 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -4658,11 +4658,11 @@ struct ShapeShiftOpConversion : public fir::FIROpConversion<fir::ShapeShiftOp> {
mlir::LLVM::UndefOp::create(rewriter, loc, llvmTy);
// Pack extent operands only; lower bounds are not part of the LLVM shape
// bundle consumed by fir.shape_extents.
- for (auto [i, extent] : llvm::enumerate(adaptor.getPairs())) {
- if (i & 1)
+ for (auto [i, pair] : llvm::enumerate(adaptor.getPairs())) {
+ if (!(i & 1))
continue;
mlir::Value extentI64 =
- integerCast(loc, rewriter, i64Ty, extent, /*fold=*/true);
+ integerCast(loc, rewriter, i64Ty, pair, /*fold=*/true);
structVal = mlir::LLVM::InsertValueOp::create(rewriter, loc, structVal,
extentI64, i / 2);
}
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
index 040196562c503..be035095fc0a6 100644
--- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -339,12 +339,6 @@ bool FIRToMemRef::materializeShapeExtents(
return true;
}
- if (auto extentsOp = shapeVal.getDefiningOp<fir::ShapeExtentsOp>()) {
- shapeVec.append(extentsOp.getExtents().begin(),
- extentsOp.getExtents().end());
- return true;
- }
-
if (mlir::isa<fir::ShapeType, fir::ShapeShiftType>(shapeVal.getType())) {
auto extentsOp = fir::ShapeExtentsOp::create(rewriter, loc, shapeVal);
shapeVec.append(extentsOp.getExtents().begin(),
diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir
index bd22dc81f05ca..b0c66e283bf5a 100644
--- a/flang/test/Fir/convert-to-llvm-invalid.fir
+++ b/flang/test/Fir/convert-to-llvm-invalid.fir
@@ -27,18 +27,6 @@ func.func @shift_not_dead(%arg0: !fir.box<!fir.array<?xf32>>, %i: index) {
// -----
-// Test `fir.shape_shift` conversion failure because the op has uses.
-
-func.func @shape_shift_not_dead(%arg0: !fir.ref<!fir.array<?x?xf32>>, %i: index, %j: index) {
- %c0 = arith.constant 1 : index
- // expected-error at +1{{failed to legalize operation 'fir.shape_shift'}}
- %0 = fir.shape_shift %c0, %c0, %c0, %c0 : (index, index, index, index) -> !fir.shapeshift<2>
- %1 = fir.array_coor %arg0(%0) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>, index, index) -> !fir.ref<f32>
- return
-}
-
-// -----
-
// Test `fir.select_type` conversion to llvm.
// Should have been converted.
diff --git a/flang/test/Fir/shape-to-llvm.mlir b/flang/test/Fir/shape-to-llvm.mlir
new file mode 100644
index 0000000000000..91df929a39a4e
--- /dev/null
+++ b/flang/test/Fir/shape-to-llvm.mlir
@@ -0,0 +1,89 @@
+// RUN: fir-opt %s --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %s
+
+// ShapeOpConversion: live fir.shape lowers to an n-field LLVM struct.
+// CHECK-LABEL: llvm.func @live_shape(
+// CHECK-SAME: %[[N:.+]]: i64
+// CHECK: %[[UNDEF:.+]] = llvm.mlir.undef
+// CHECK: %[[STRUCT:.+]] = llvm.insertvalue %[[N]], %[[UNDEF]][0]
+// CHECK: llvm.extractvalue %[[STRUCT]][0]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape(%n : index) {
+ %sh = fir.shape %n : (index) -> !fir.shape<1>
+ %e = fir.shape_extents %sh : (!fir.shape<1>) -> index
+ %c0 = arith.constant 0 : index
+ %sink = arith.addi %e, %c0 : index
+ return
+}
+
+// -----
+
+// ShapeShiftOpConversion: packs extent operands only (n-field struct, not 2n).
+// ShapeExtentsOpConversion: extractvalue field 0 for rank-1 shapeshift.
+// CHECK-LABEL: llvm.func @live_shape_shift(
+// CHECK-SAME: %[[LB:.+]]: i64, %[[EXT:.+]]: i64
+// CHECK: %[[UNDEF:.+]] = llvm.mlir.undef
+// CHECK: %[[STRUCT:.+]] = llvm.insertvalue %[[EXT]], %[[UNDEF]][0]
+// CHECK-NOT: llvm.insertvalue %[[LB]]
+// CHECK: llvm.extractvalue %[[STRUCT]][0]
+// CHECK-NOT: fir.shape_shift
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape_shift(%lb : index, %ext : index) {
+ %ss = fir.shape_shift %lb, %ext : (index, index) -> !fir.shapeshift<1>
+ %e = fir.shape_extents %ss : (!fir.shapeshift<1>) -> index
+ %c0 = arith.constant 0 : index
+ %sink = arith.addi %e, %c0 : index
+ return
+}
+
+// -----
+
+// ShapeExtentsOpConversion on a forwarded !fir.shape block argument.
+// CHECK-LABEL: llvm.func @live_shape_extents_forwarded(
+// CHECK-SAME: %[[PRED:.+]]: i1, %[[N1:.+]]: i64, %[[N2:.+]]: i64
+// CHECK: llvm.cond_br %[[PRED]]
+// CHECK: %[[UNDEF1:.+]] = llvm.mlir.undef
+// CHECK: %[[S1:.+]] = llvm.insertvalue %[[N1]], %[[UNDEF1]][0]
+// CHECK: llvm.br {{.*}}(%[[S1]]
+// CHECK: %[[UNDEF2:.+]] = llvm.mlir.undef
+// CHECK: %[[S2:.+]] = llvm.insertvalue %[[N2]], %[[UNDEF2]][0]
+// CHECK: llvm.br {{.*}}(%[[S2]]
+// CHECK: llvm.extractvalue %{{.*}}[0]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+// CHECK: llvm.return
+func.func @live_shape_extents_forwarded(%pred : i1, %n1 : index, %n2 : index) {
+ cf.cond_br %pred, ^bb1, ^bb2
+^bb1:
+ %sh1 = fir.shape %n1 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh1 : !fir.shape<1>)
+^bb2:
+ %sh2 = fir.shape %n2 : (index) -> !fir.shape<1>
+ cf.br ^bb3(%sh2 : !fir.shape<1>)
+^bb3(%phi : !fir.shape<1>):
+ %e = fir.shape_extents %phi : (!fir.shape<1>) -> index
+ %c0 = arith.constant 0 : index
+ %sink = arith.addi %e, %c0 : index
+ return
+}
+
+// -----
+
+// 2-D shape: struct has two extent fields; shape_extents extracts [0] and [1].
+// CHECK-LABEL: llvm.func @live_shape_extents_2d(
+// CHECK: llvm.insertvalue {{.+}}[0]
+// CHECK: llvm.insertvalue {{.+}}[1]
+// CHECK: llvm.extractvalue {{.+}}[0]
+// CHECK: llvm.extractvalue {{.+}}[1]
+// CHECK-NOT: fir.shape
+// CHECK-NOT: fir.shape_extents
+func.func @live_shape_extents_2d(%n1 : index, %n2 : index) {
+ %sh = fir.shape %n1, %n2 : (index, index) -> !fir.shape<2>
+ %e0, %e1 = fir.shape_extents %sh : (!fir.shape<2>) -> (index, index)
+ %c0 = arith.constant 0 : index
+ %s0 = arith.addi %e0, %c0 : index
+ %s1 = arith.addi %e1, %c0 : index
+ return
+}
More information about the flang-commits
mailing list