[flang-commits] [flang] [flang] Inline hlfir.reshape as hlfir.elemental. (PR #124683)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Tue Jan 28 12:25:31 PST 2025
https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/124683
>From 5b6b2017ee404e00f2c9aad08b4a2c87b159d0ee Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 27 Jan 2025 18:58:56 -0800
Subject: [PATCH 1/3] [flang] Inline hlfir.reshape as hlfir.elemental.
This patch inlines hlfir.reshape for simple cases, such as
when there is no ORDER argument; and when PAD is present,
only the trivial types are handled.
---
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 208 +++++++++++++++++
.../simplify-hlfir-intrinsics-reshape.fir | 216 ++++++++++++++++++
2 files changed, 424 insertions(+)
create mode 100644 flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index fe7ae0eeed3cc3..35071361fa16b8 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -951,6 +951,213 @@ class DotProductConversion
}
};
+class ReshapeAsElementalConversion
+ : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::ReshapeOp reshape,
+ mlir::PatternRewriter &rewriter) const override {
+ // Do not inline RESHAPE with ORDER yet. The runtime implementation
+ // may be good enough, unless the temporary creation overhead
+ // is high.
+ // TODO: If ORDER is constant, then we can still easily inline.
+ // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
+ if (reshape.getOrder())
+ return rewriter.notifyMatchFailure(reshape,
+ "RESHAPE with ORDER argument");
+
+ // Verify that the element types of ARRAY, PAD and the result
+ // match before doing any transformations.
+ hlfir::Entity result = hlfir::Entity{reshape};
+ hlfir::Entity array = hlfir::Entity{reshape.getArray()};
+ mlir::Type elementType = array.getFortranElementType();
+ if (result.getFortranElementType() != elementType)
+ return rewriter.notifyMatchFailure(
+ reshape, "ARRAY and result have different types");
+ mlir::Value pad = reshape.getPad();
+ if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
+ return rewriter.notifyMatchFailure(reshape,
+ "ARRAY and PAD have different types");
+
+ // TODO: selecting between ARRAY and PAD of non-trivial element types
+ // requires more work. We have to select between two references
+ // to elements in ARRAY and PAD. This requires conditional
+ // bufferization of the element, if ARRAY/PAD is an expression.
+ if (pad && !fir::isa_trivial(elementType))
+ return rewriter.notifyMatchFailure(reshape,
+ "PAD present with non-trivial type");
+
+ mlir::Location loc = reshape.getLoc();
+ fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
+ // Assume that all the indices arithmetic does not overflow
+ // the IndexType.
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
+
+ llvm::SmallVector<mlir::Value, 1> typeParams;
+ hlfir::genLengthParameters(loc, builder, array, typeParams);
+
+ // Fetch the extents of ARRAY, PAD and result beforehand.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
+ hlfir::genExtentsVector(loc, builder, array);
+
+ mlir::Value arraySize, padSize;
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
+ if (pad) {
+ // If PAD is present, we have to use array size to start taking
+ // elements from the PAD array.
+ arraySize = computeArraySize(loc, builder, arrayExtents);
+
+ padExtents = hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
+ // PAD size is needed to wrap around the linear index addressing
+ // the PAD array.
+ padSize = computeArraySize(loc, builder, padExtents);
+ }
+ hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
+ mlir::Type indexType = builder.getIndexType();
+ for (int idx = 0; idx < result.getRank(); ++idx)
+ resultExtents.push_back(hlfir::loadElementAt(
+ loc, builder, shape,
+ builder.createIntegerConstant(loc, indexType, idx + 1)));
+ auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
+
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
+ mlir::Value linearIndex =
+ computeLinearIndex(loc, builder, resultExtents, inputIndices);
+ fir::IfOp ifOp;
+ if (pad) {
+ // PAD is present. Check if this element comes from the PAD array.
+ mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
+ ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
+ /*withElseRegion=*/true);
+
+ // In the 'else' block, return an element from the PAD.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // Subtract the ARRAY size from the zero-based linear index
+ // to get the zero-based linear index into PAD.
+ mlir::Value padLinearIndex =
+ builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
+ // PAD wraps around, when additional elements are needed.
+ padLinearIndex =
+ builder.create<mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
+ delinearizeIndex(loc, builder, padExtents, padLinearIndex);
+ mlir::Value padElement =
+ hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
+ builder.create<fir::ResultOp>(loc, padElement);
+
+ // In the 'then' block, return an element from the ARRAY.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
+ delinearizeIndex(loc, builder, arrayExtents, linearIndex);
+ mlir::Value arrayElement =
+ hlfir::loadElementAt(loc, builder, array, arrayIndices);
+
+ if (ifOp) {
+ builder.create<fir::ResultOp>(loc, arrayElement);
+ builder.setInsertionPointAfter(ifOp);
+ arrayElement = ifOp.getResult(0);
+ }
+
+ return hlfir::Entity{arrayElement};
+ };
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+ loc, builder, elementType, resultShape, typeParams, genKernel,
+ /*isUnordered=*/true,
+ /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
+ reshape.getResult().getType());
+ assert(elementalOp.getResult().getType() == reshape.getResult().getType());
+ rewriter.replaceOp(reshape, elementalOp);
+ return mlir::success();
+ }
+
+private:
+ /// Compute zero-based linear index given the array extents
+ /// and one-based indices:
+ /// \p extents: [e0, e1, ..., en]
+ /// \p indices: [i0, i1, ..., in]
+ ///
+ /// linear-index :=
+ /// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
+ static mlir::Value computeLinearIndex(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::ValueRange extents,
+ mlir::ValueRange indices) {
+ std::size_t rank = extents.size();
+ assert(rank == indices.size() && "extents/indices rank mismatch");
+ mlir::Type indexType = builder.getIndexType();
+ mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0);
+ mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
+ mlir::Value linearIndex = zero;
+ for (auto idx : llvm::enumerate(llvm::reverse(indices))) {
+ mlir::Value tmp = builder.create<mlir::arith::SubIOp>(
+ loc, builder.createConvert(loc, indexType, idx.value()), one);
+ tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp);
+ if (idx.index() + 1 < rank)
+ tmp = builder.create<mlir::arith::MulIOp>(
+ loc, tmp,
+ builder.createConvert(loc, indexType,
+ extents[rank - idx.index() - 2]));
+
+ linearIndex = tmp;
+ }
+ return linearIndex;
+ }
+
+ /// Compute one-based array indices from the given zero-based \p linearIndex
+ /// and the array \p extents [e0, e1, ..., en].
+ /// i0 := linearIndex % e0 + 1
+ /// linearIndex := linearIndex / e0
+ /// i1 := linearIndex % e1 + 1
+ /// linearIndex := linearIndex / e1
+ /// ...
+ /// i(n-1) := linearIndex % e(n-1) + 1
+ /// linearIndex := linearIndex / e(n-1)
+ /// in := linearIndex + 1
+ static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
+ delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange extents, mlir::Value linearIndex) {
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
+ mlir::Type indexType = builder.getIndexType();
+ mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
+ linearIndex = builder.createConvert(loc, indexType, linearIndex);
+
+ for (std::size_t dim = 0; dim < extents.size(); ++dim) {
+ mlir::Value currentIndex;
+ if (dim == extents.size() - 1) {
+ currentIndex = linearIndex;
+ } else {
+ mlir::Value extent =
+ builder.createConvert(loc, indexType, extents[dim]);
+ currentIndex =
+ builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
+ linearIndex =
+ builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
+ }
+ indices.push_back(
+ builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
+ }
+ return indices;
+ }
+
+ static mlir::Value computeArraySize(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::ValueRange extents) {
+ mlir::Type indexType = builder.getIndexType();
+ mlir::Value size = builder.createIntegerConstant(loc, indexType, 1);
+ for (auto extent : extents)
+ size = builder.create<mlir::arith::MulIOp>(
+ loc, size, builder.createConvert(loc, indexType, extent));
+ return size;
+ }
+};
+
class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
@@ -987,6 +1194,7 @@ class SimplifyHLFIRIntrinsics
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
patterns.insert<DotProductConversion>(context);
+ patterns.insert<ReshapeAsElementalConversion>(context);
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
new file mode 100644
index 00000000000000..ad8093335556c0
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
@@ -0,0 +1,216 @@
+// Test hlfir.reshape simplification to hlfir.elemental:
+// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
+
+func.func @reshape_simple(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32> {
+ %res = hlfir.reshape %arg0 %arg1 : (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32>
+ return %res : !hlfir.expr<?xf32>
+}
+// CHECK-LABEL: func.func @reshape_simple(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: index):
+// CHECK: %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]]#0, %[[VAL_2]] overflow<nuw> : index
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_10]] overflow<nuw> : index
+// CHECK: %[[VAL_12:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_11]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_12]] : !fir.ref<f32>
+// CHECK: hlfir.yield_element %[[VAL_13]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_7]] : !hlfir.expr<?xf32>
+// CHECK: }
+
+func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.ref<!fir.array<2xi32>>, %arg2: !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32> {
+ %res = hlfir.reshape %arg0 %arg1 pad %arg2 : (!fir.box<!fir.array<?x?x?xf32>>, !fir.ref<!fir.array<2xi32>>, !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32>
+ return %res : !hlfir.expr<?x?xf32>
+}
+// CHECK-LABEL: func.func @reshape_with_pad(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x?x?xf32>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<2xi32>>,
+// CHECK-SAME: %[[VAL_2:.*]]: !fir.box<!fir.array<?x?x?xf32>>) -> !hlfir.expr<?x?xf32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[ARRAY_DIM0:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[ARRAY_DIM1:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[ARRAY_DIM2:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARRAY_DIM0]]#1, %[[ARRAY_DIM1]]#1 overflow<nuw> : index
+// CHECK: %[[ARRAY_SIZE:.*]] = arith.muli %[[VAL_9]], %[[ARRAY_DIM2]]#1 overflow<nuw> : index
+// CHECK: %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_14:.*]] = arith.muli %[[PAD_DIM0]]#1, %[[PAD_DIM1]]#1 overflow<nuw> : index
+// CHECK: %[[PAD_SIZE:.*]] = arith.muli %[[VAL_14]], %[[PAD_DIM2]]#1 overflow<nuw> : index
+// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_4]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_3]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+// CHECK: %[[VAL_20:.*]] = fir.shape %[[VAL_17]], %[[VAL_19]] : (i32, i32) -> !fir.shape<2>
+// CHECK: %[[VAL_21:.*]] = hlfir.elemental %[[VAL_20]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index):
+// CHECK: %[[VAL_24:.*]] = arith.subi %[[VAL_23]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_17]] : (i32) -> index
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_24]], %[[VAL_25]] overflow<nuw> : index
+// CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_22]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[LINEAR_INDEX:.*]] = arith.addi %[[VAL_26]], %[[VAL_27]] overflow<nuw> : index
+// CHECK: %[[IS_WITHIN_ARRAY:.*]] = arith.cmpi ult, %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] : index
+// CHECK: %[[VAL_30:.*]] = fir.if %[[IS_WITHIN_ARRAY]] -> (f32) {
+// CHECK: %[[VAL_31:.*]] = arith.remui %[[LINEAR_INDEX]], %[[ARRAY_DIM0]]#1 : index
+// CHECK: %[[VAL_32:.*]] = arith.divui %[[LINEAR_INDEX]], %[[ARRAY_DIM0]]#1 : index
+// CHECK: %[[ARRAY_IDX0:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_34:.*]] = arith.remui %[[VAL_32]], %[[ARRAY_DIM1]]#1 : index
+// CHECK: %[[VAL_35:.*]] = arith.divui %[[VAL_32]], %[[ARRAY_DIM1]]#1 : index
+// CHECK: %[[ARRAY_IDX1:.*]] = arith.addi %[[VAL_34]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[ARRAY_IDX2:.*]] = arith.addi %[[VAL_35]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_38:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_39:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_40:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_41:.*]] = arith.subi %[[VAL_38]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_42:.*]] = arith.addi %[[ARRAY_IDX0]], %[[VAL_41]] overflow<nuw> : index
+// CHECK: %[[VAL_43:.*]] = arith.subi %[[VAL_39]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[ARRAY_IDX1]], %[[VAL_43]] overflow<nuw> : index
+// CHECK: %[[VAL_45:.*]] = arith.subi %[[VAL_40]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_46:.*]] = arith.addi %[[ARRAY_IDX2]], %[[VAL_45]] overflow<nuw> : index
+// CHECK: %[[VAL_47:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_42]], %[[VAL_44]], %[[VAL_46]]) : (!fir.box<!fir.array<?x?x?xf32>>, index, index, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_47]] : !fir.ref<f32>
+// CHECK: fir.result %[[VAL_48]] : f32
+// CHECK: } else {
+// CHECK: %[[PAD_LINEAR_INDEX:.*]] = arith.subi %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] overflow<nuw> : index
+// CHECK: %[[PAD_LINEAR_INDEX_MOD:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_SIZE]] : index
+// CHECK: %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
+// CHECK: %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
+// CHECK: %[[PAD_IDX0:.*]] = arith.addi %[[VAL_51]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_54:.*]] = arith.remui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
+// CHECK: %[[VAL_55:.*]] = arith.divui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
+// CHECK: %[[PAD_IDX1:.*]] = arith.addi %[[VAL_54]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[PAD_IDX2:.*]] = arith.addi %[[VAL_55]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_58:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_59:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_60:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[VAL_61:.*]] = arith.subi %[[VAL_58]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_62:.*]] = arith.addi %[[PAD_IDX0]], %[[VAL_61]] overflow<nuw> : index
+// CHECK: %[[VAL_63:.*]] = arith.subi %[[VAL_59]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_64:.*]] = arith.addi %[[PAD_IDX1]], %[[VAL_63]] overflow<nuw> : index
+// CHECK: %[[VAL_65:.*]] = arith.subi %[[VAL_60]]#0, %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_66:.*]] = arith.addi %[[PAD_IDX2]], %[[VAL_65]] overflow<nuw> : index
+// CHECK: %[[VAL_67:.*]] = hlfir.designate %[[VAL_2]] (%[[VAL_62]], %[[VAL_64]], %[[VAL_66]]) : (!fir.box<!fir.array<?x?x?xf32>>, index, index, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_68:.*]] = fir.load %[[VAL_67]] : !fir.ref<f32>
+// CHECK: fir.result %[[VAL_68]] : f32
+// CHECK: }
+// CHECK: hlfir.yield_element %[[VAL_30]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_21]] : !hlfir.expr<?x?xf32>
+// CHECK: }
+
+func.func @reshape_derived_obj(%arg0: !fir.ref<!fir.array<10x!fir.type<whatever>>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>> {
+ %res = hlfir.reshape %arg0 %arg1 : (!fir.ref<!fir.array<10x!fir.type<whatever>>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>>
+ return %res : !hlfir.expr<?x!fir.type<whatever>>
+}
+// CHECK-LABEL: func.func @reshape_derived_obj(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<10x!fir.type<whatever>>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.type<whatever>> {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index):
+// CHECK: %[[VAL_8:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_7]]) : (!fir.ref<!fir.array<10x!fir.type<whatever>>>, index) -> !fir.ref<!fir.type<whatever>>
+// CHECK: hlfir.yield_element %[[VAL_8]] : !fir.ref<!fir.type<whatever>>
+// CHECK: }
+// CHECK: return %[[VAL_6]] : !hlfir.expr<?x!fir.type<whatever>>
+// CHECK: }
+
+func.func @reshape_derived_expr(%arg0: !hlfir.expr<?x!fir.type<whatever>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>> {
+ %res = hlfir.reshape %arg0 %arg1 : (!hlfir.expr<?x!fir.type<whatever>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>>
+ return %res : !hlfir.expr<?x!fir.type<whatever>>
+}
+// CHECK-LABEL: func.func @reshape_derived_expr(
+// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x!fir.type<whatever>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?x!fir.type<whatever>> {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index):
+// CHECK: %[[VAL_8:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?x!fir.type<whatever>>, index) -> !hlfir.expr<!fir.type<whatever>>
+// CHECK: hlfir.yield_element %[[VAL_8]] : !hlfir.expr<!fir.type<whatever>>
+// CHECK: }
+// CHECK: return %[[VAL_6]] : !hlfir.expr<?x!fir.type<whatever>>
+// CHECK: }
+
+func.func @reshape_poly_obj(%arg0: !fir.class<!fir.array<?x!fir.type<whatever>>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+ %res = hlfir.reshape %arg0 %arg1 : (!fir.class<!fir.array<?x!fir.type<whatever>>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?>
+ return %res : !hlfir.expr<?x!fir.type<whatever>?>
+}
+// CHECK-LABEL: func.func @reshape_poly_obj(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.class<!fir.array<?x!fir.type<whatever>>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+// CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_7:.*]] = hlfir.elemental %[[VAL_6]] mold %[[VAL_0]] unordered : (!fir.shape<1>, !fir.class<!fir.array<?x!fir.type<whatever>>>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: index):
+// CHECK: %[[VAL_9:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.class<!fir.array<?x!fir.type<whatever>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]]#0, %[[VAL_2]] overflow<nuw> : index
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_10]] overflow<nuw> : index
+// CHECK: %[[VAL_12:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_11]]) : (!fir.class<!fir.array<?x!fir.type<whatever>>>, index) -> !fir.class<!fir.type<whatever>>
+// CHECK: hlfir.yield_element %[[VAL_12]] : !fir.class<!fir.type<whatever>>
+// CHECK: }
+// CHECK: return %[[VAL_7]] : !hlfir.expr<?x!fir.type<whatever>?>
+// CHECK: }
+
+func.func @reshape_poly_expr(%arg0: !hlfir.expr<?x!fir.type<whatever>?>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+ %res = hlfir.reshape %arg0 %arg1 : (!hlfir.expr<?x!fir.type<whatever>?>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?>
+ return %res : !hlfir.expr<?x!fir.type<whatever>?>
+}
+// CHECK-LABEL: func.func @reshape_poly_expr(
+// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x!fir.type<whatever>?>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<i32>
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] mold %[[VAL_0]] unordered : (!fir.shape<1>, !hlfir.expr<?x!fir.type<whatever>?>) -> !hlfir.expr<?x!fir.type<whatever>?> {
+// CHECK: ^bb0(%[[VAL_7:.*]]: index):
+// CHECK: %[[VAL_8:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?x!fir.type<whatever>?>, index) -> !hlfir.expr<!fir.type<whatever>?>
+// CHECK: hlfir.yield_element %[[VAL_8]] : !hlfir.expr<!fir.type<whatever>?>
+// CHECK: }
+// CHECK: return %[[VAL_6]] : !hlfir.expr<?x!fir.type<whatever>?>
+// CHECK: }
+
+func.func @reshape_char(%arg0: !fir.box<!fir.array<?x!fir.char<2,?>>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,?>> {
+ %res = hlfir.reshape %arg0 %arg1 : (!fir.box<!fir.array<?x!fir.char<2,?>>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,?>>
+ return %res : !hlfir.expr<?x!fir.char<2,?>>
+}
+// CHECK-LABEL: func.func @reshape_char(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?x!fir.char<2,?>>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,?>> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_5:.*]] = fir.box_elesize %[[VAL_0]] : (!fir.box<!fir.array<?x!fir.char<2,?>>>) -> index
+// CHECK: %[[VAL_6:.*]] = arith.divsi %[[VAL_5]], %[[VAL_4]] : index
+// CHECK: %[[VAL_7:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_2]]) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (i32) -> !fir.shape<1>
+// CHECK: %[[VAL_10:.*]] = hlfir.elemental %[[VAL_9]] typeparams %[[VAL_6]] unordered : (!fir.shape<1>, index) -> !hlfir.expr<?x!fir.char<2,?>> {
+// CHECK: ^bb0(%[[VAL_11:.*]]: index):
+// CHECK: %[[VAL_12:.*]] = fir.box_elesize %[[VAL_0]] : (!fir.box<!fir.array<?x!fir.char<2,?>>>) -> index
+// CHECK: %[[VAL_13:.*]] = arith.divsi %[[VAL_12]], %[[VAL_4]] : index
+// CHECK: %[[VAL_14:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x!fir.char<2,?>>>, index) -> (index, index, index)
+// CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]]#0, %[[VAL_2]] overflow<nuw> : index
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_11]], %[[VAL_15]] overflow<nuw> : index
+// CHECK: %[[VAL_17:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_16]]) typeparams %[[VAL_13]] : (!fir.box<!fir.array<?x!fir.char<2,?>>>, index, index) -> !fir.boxchar<2>
+// CHECK: hlfir.yield_element %[[VAL_17]] : !fir.boxchar<2>
+// CHECK: }
+// CHECK: return %[[VAL_10]] : !hlfir.expr<?x!fir.char<2,?>>
+// CHECK: }
>From 30e193f90582bcc31ea426844f84c79be285cabb Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 28 Jan 2025 12:02:55 -0800
Subject: [PATCH 2/3] Fixed handling of dynamically optional PAD.
---
.../Transforms/SimplifyHLFIRIntrinsics.cpp | 58 ++++++++++---------
.../simplify-hlfir-intrinsics-reshape.fir | 16 +++--
2 files changed, 37 insertions(+), 37 deletions(-)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 35071361fa16b8..bac94c9e5fd1b8 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -1002,18 +1002,10 @@ class ReshapeAsElementalConversion
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
hlfir::genExtentsVector(loc, builder, array);
- mlir::Value arraySize, padSize;
- llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
- if (pad) {
- // If PAD is present, we have to use array size to start taking
- // elements from the PAD array.
- arraySize = computeArraySize(loc, builder, arrayExtents);
-
- padExtents = hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
- // PAD size is needed to wrap around the linear index addressing
- // the PAD array.
- padSize = computeArraySize(loc, builder, padExtents);
- }
+ // If PAD is present, we have to use array size to start taking
+ // elements from the PAD array.
+ mlir::Value arraySize =
+ pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
mlir::Type indexType = builder.getIndexType();
@@ -1037,15 +1029,18 @@ class ReshapeAsElementalConversion
// In the 'else' block, return an element from the PAD.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // PAD is dynamically optional, but we can unconditionally access it
+ // in the 'else' block. If we have to start taking elements from it,
+ // then it must be present in a valid program.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
+ hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
// Subtract the ARRAY size from the zero-based linear index
// to get the zero-based linear index into PAD.
mlir::Value padLinearIndex =
builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
- // PAD wraps around, when additional elements are needed.
- padLinearIndex =
- builder.create<mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
- delinearizeIndex(loc, builder, padExtents, padLinearIndex);
+ delinearizeIndex(loc, builder, padExtents, padLinearIndex,
+ /*wrapAround=*/true);
mlir::Value padElement =
hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
builder.create<fir::ResultOp>(loc, padElement);
@@ -1055,7 +1050,8 @@ class ReshapeAsElementalConversion
}
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
- delinearizeIndex(loc, builder, arrayExtents, linearIndex);
+ delinearizeIndex(loc, builder, arrayExtents, linearIndex,
+ /*wrapAround=*/false);
mlir::Value arrayElement =
hlfir::loadElementAt(loc, builder, array, arrayIndices);
@@ -1119,33 +1115,39 @@ class ReshapeAsElementalConversion
/// ...
/// i(n-1) := linearIndex % e(n-1) + 1
/// linearIndex := linearIndex / e(n-1)
- /// in := linearIndex + 1
+ /// if (wrapAround) {
+ /// // If the index is allowed to wrap around, then
+ /// // we need to modulo it by the last dimension's extent.
+ /// in := linearIndex % en + 1
+ /// } else {
+ /// in := linearIndex + 1
+ /// }
static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange extents, mlir::Value linearIndex) {
+ mlir::ValueRange extents, mlir::Value linearIndex,
+ bool wrapAround) {
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
mlir::Type indexType = builder.getIndexType();
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
linearIndex = builder.createConvert(loc, indexType, linearIndex);
for (std::size_t dim = 0; dim < extents.size(); ++dim) {
- mlir::Value currentIndex;
- if (dim == extents.size() - 1) {
- currentIndex = linearIndex;
- } else {
- mlir::Value extent =
- builder.createConvert(loc, indexType, extents[dim]);
+ mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]);
+ // Avoid the modulo for the last index, unless wrap around is allowed.
+ mlir::Value currentIndex = linearIndex;
+ if (dim != extents.size() - 1 || wrapAround)
currentIndex =
builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
- linearIndex =
- builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
- }
+ // The result of the last division is unused, so it will be DCEd.
+ linearIndex =
+ builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
indices.push_back(
builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
}
return indices;
}
+ /// Return size of an array given its extents.
static mlir::Value computeArraySize(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::ValueRange extents) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
index ad8093335556c0..4abf24e1efe1f6 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
@@ -41,11 +41,6 @@ func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.
// CHECK: %[[ARRAY_DIM2:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARRAY_DIM0]]#1, %[[ARRAY_DIM1]]#1 overflow<nuw> : index
// CHECK: %[[ARRAY_SIZE:.*]] = arith.muli %[[VAL_9]], %[[ARRAY_DIM2]]#1 overflow<nuw> : index
-// CHECK: %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
-// CHECK: %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
-// CHECK: %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
-// CHECK: %[[VAL_14:.*]] = arith.muli %[[PAD_DIM0]]#1, %[[PAD_DIM1]]#1 overflow<nuw> : index
-// CHECK: %[[PAD_SIZE:.*]] = arith.muli %[[VAL_14]], %[[PAD_DIM2]]#1 overflow<nuw> : index
// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_4]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_3]]) : (!fir.ref<!fir.array<2xi32>>, index) -> !fir.ref<i32>
@@ -80,15 +75,18 @@ func.func @reshape_with_pad(%arg0: !fir.box<!fir.array<?x?x?xf32>>, %arg1: !fir.
// CHECK: %[[VAL_48:.*]] = fir.load %[[VAL_47]] : !fir.ref<f32>
// CHECK: fir.result %[[VAL_48]] : f32
// CHECK: } else {
+// CHECK: %[[PAD_DIM0:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[PAD_DIM1:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
+// CHECK: %[[PAD_DIM2:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
// CHECK: %[[PAD_LINEAR_INDEX:.*]] = arith.subi %[[LINEAR_INDEX]], %[[ARRAY_SIZE]] overflow<nuw> : index
-// CHECK: %[[PAD_LINEAR_INDEX_MOD:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_SIZE]] : index
-// CHECK: %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
-// CHECK: %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX_MOD]], %[[PAD_DIM0]]#1 : index
+// CHECK: %[[VAL_51:.*]] = arith.remui %[[PAD_LINEAR_INDEX]], %[[PAD_DIM0]]#1 : index
+// CHECK: %[[VAL_52:.*]] = arith.divui %[[PAD_LINEAR_INDEX]], %[[PAD_DIM0]]#1 : index
// CHECK: %[[PAD_IDX0:.*]] = arith.addi %[[VAL_51]], %[[VAL_4]] overflow<nuw> : index
// CHECK: %[[VAL_54:.*]] = arith.remui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
// CHECK: %[[VAL_55:.*]] = arith.divui %[[VAL_52]], %[[PAD_DIM1]]#1 : index
// CHECK: %[[PAD_IDX1:.*]] = arith.addi %[[VAL_54]], %[[VAL_4]] overflow<nuw> : index
-// CHECK: %[[PAD_IDX2:.*]] = arith.addi %[[VAL_55]], %[[VAL_4]] overflow<nuw> : index
+// CHECK: %[[VAL_56:.*]] = arith.remui %[[VAL_55]], %[[PAD_DIM2]]#1 : index
+// CHECK: %[[PAD_IDX2:.*]] = arith.addi %[[VAL_56]], %[[VAL_4]] overflow<nuw> : index
// CHECK: %[[VAL_58:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_5]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
// CHECK: %[[VAL_59:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_4]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
// CHECK: %[[VAL_60:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_3]] : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
>From b1adf5de4611ec07dc3d8dc04a03c32206325e6c Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 28 Jan 2025 12:23:59 -0800
Subject: [PATCH 3/3] Added negative tests and a comment.
---
.../HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 5 ++++-
.../HLFIR/simplify-hlfir-intrinsics-reshape.fir | 14 ++++++++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index bac94c9e5fd1b8..cbed562ef45889 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -969,7 +969,10 @@ class ReshapeAsElementalConversion
"RESHAPE with ORDER argument");
// Verify that the element types of ARRAY, PAD and the result
- // match before doing any transformations.
+ // match before doing any transformations. For example,
+ // character types of different lengths may appear in dead
+ // code, and it does not make sense to inline hlfir.reshape
+ // in that case (a runtime call may have a smaller code size footprint).
hlfir::Entity result = hlfir::Entity{reshape};
hlfir::Entity array = hlfir::Entity{reshape.getArray()};
mlir::Type elementType = array.getFortranElementType();
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
index 4abf24e1efe1f6..afbd3bcd6d98c7 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-reshape.fir
@@ -212,3 +212,17 @@ func.func @reshape_char(%arg0: !fir.box<!fir.array<?x!fir.char<2,?>>>, %arg1: !f
// CHECK: }
// CHECK: return %[[VAL_10]] : !hlfir.expr<?x!fir.char<2,?>>
// CHECK: }
+
+func.func @reshape_negative_result_array_have_different_types(%arg0: !fir.box<!fir.array<?x!fir.char<2,1>>>, %arg1: !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,2>> {
+ %res = hlfir.reshape %arg0 %arg1 : (!fir.box<!fir.array<?x!fir.char<2,1>>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,2>>
+ return %res : !hlfir.expr<?x!fir.char<2,2>>
+}
+// CHECK-LABEL: func.func @reshape_negative_result_array_have_different_types(
+// CHECK: hlfir.reshape %{{.*}} %{{.*}} : (!fir.box<!fir.array<?x!fir.char<2>>>, !fir.ref<!fir.array<1xi32>>) -> !hlfir.expr<?x!fir.char<2,2>>
+
+func.func @reshape_negative_array_pad_have_different_types(%arg0: !fir.box<!fir.array<?x!fir.char<2,2>>>, %arg1: !fir.ref<!fir.array<1xi32>>, %arg2: !fir.box<!fir.array<?x!fir.char<2,1>>>) -> !hlfir.expr<?x!fir.char<2,2>> {
+ %res = hlfir.reshape %arg0 %arg1 pad %arg2 : (!fir.box<!fir.array<?x!fir.char<2,2>>>, !fir.ref<!fir.array<1xi32>>, !fir.box<!fir.array<?x!fir.char<2,1>>>) -> !hlfir.expr<?x!fir.char<2,2>>
+ return %res : !hlfir.expr<?x!fir.char<2,2>>
+}
+// CHECK-LABEL: func.func @reshape_negative_array_pad_have_different_types(
+// CHECK: hlfir.reshape %{{.*}} %{{.*}} pad %{{.*}} : (!fir.box<!fir.array<?x!fir.char<2,2>>>, !fir.ref<!fir.array<1xi32>>, !fir.box<!fir.array<?x!fir.char<2>>>) -> !hlfir.expr<?x!fir.char<2,2>>
More information about the flang-commits mailing list