[flang-commits] [flang] [flang] Inline hlfir.reshape as hlfir.elemental. (PR #124683)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Thu Jan 30 16:17:34 PST 2025
@@ -951,6 +951,218 @@ class DotProductConversion
+class ReshapeAsElementalConversion
+ : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
+ using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
+ /// Rewrite an hlfir.reshape without ORDER into an hlfir.elemental.
+ /// The elemental kernel turns each one-based result index tuple into a
+ /// zero-based linear index (via computeLinearIndex over the result
+ /// extents). It then loads the element at that linear position from
+ /// ARRAY, or from PAD once the linear index reaches SIZE(ARRAY).
+ /// The pattern reports a match failure (leaving hlfir.reshape untouched)
+ /// when:
+ ///   - ORDER is present;
+ ///   - the element types of ARRAY and the result differ;
+ ///   - the element types of ARRAY and PAD differ;
+ ///   - PAD is present and the element type is not trivial.
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::ReshapeOp reshape,
+ mlir::PatternRewriter &rewriter) const override {
+ // Do not inline RESHAPE with ORDER yet. The runtime implementation
+ // may be good enough, unless the temporary creation overhead
+ // is high.
+ // TODO: If ORDER is constant, then we can still easily inline.
+ // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
+ if (reshape.getOrder())
+ return rewriter.notifyMatchFailure(reshape,
+ "RESHAPE with ORDER argument");
+ // Verify that the element types of ARRAY, PAD and the result
+ // match before doing any transformations. For example,
+ // character types of different lengths may appear in dead
+ // code, and it does not make sense to inline hlfir.reshape
+ // in that case (a runtime call may have a smaller code-size footprint).
+ hlfir::Entity result = hlfir::Entity{reshape};
+ hlfir::Entity array = hlfir::Entity{reshape.getArray()};
+ mlir::Type elementType = array.getFortranElementType();
+ if (result.getFortranElementType() != elementType)
+ return rewriter.notifyMatchFailure(
+ reshape, "ARRAY and result have different types");
+ // PAD is optional: a null Value means it is not present on the op.
+ mlir::Value pad = reshape.getPad();
+ if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
+ return rewriter.notifyMatchFailure(reshape,
+ "ARRAY and PAD have different types");
+ // TODO: selecting between ARRAY and PAD of non-trivial element types
+ // requires more work. We have to select between two references
+ // to elements in ARRAY and PAD. This requires conditional
+ // bufferization of the element, if ARRAY/PAD is an expression.
+ if (pad && !fir::isa_trivial(elementType))
+ return rewriter.notifyMatchFailure(reshape,
+ "PAD present with non-trivial type");
+ // All match checks passed; from here on the IR is rewritten.
+ mlir::Location loc = reshape.getLoc();
+ fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
+ // Assume that none of the index arithmetic overflows the IndexType
+ // in the unsigned sense, so mark the generated arith ops 'nuw'.
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
+ // The result's length parameters are taken from ARRAY; the element
+ // types of ARRAY and the result were checked to be equal above.
+ llvm::SmallVector<mlir::Value, 1> typeParams;
+ hlfir::genLengthParameters(loc, builder, array, typeParams);
+ // Fetch the extents of ARRAY and of the result up front (outside the
+ // kernel). PAD's extents are only needed in the kernel's 'else' branch
+ // and are generated there.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
+ hlfir::genExtentsVector(loc, builder, array);
+ // If PAD is present, SIZE(ARRAY) is the linear position at which
+ // elements start being taken from PAD instead of ARRAY.
+ mlir::Value arraySize =
+ pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
+ // The result extents are the values of the SHAPE argument, loaded
+ // with one-based indices 1..rank.
+ hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
+ mlir::Type indexType = builder.getIndexType();
+ for (int idx = 0; idx < result.getRank(); ++idx)
+ resultExtents.push_back(hlfir::loadElementAt(
+ loc, builder, shape,
+ builder.createIntegerConstant(loc, indexType, idx + 1)));
+ auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
+ // Kernel of the elemental: produce the result element at the one-based
+ // result indices 'inputIndices'.
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
+ mlir::Value linearIndex =
+ computeLinearIndex(loc, builder, resultExtents, inputIndices);
+ // Stays null when PAD is absent; then the ARRAY element is loaded
+ // directly, with no fir.if.
+ fir::IfOp ifOp;
+ if (pad) {
+ // PAD is present. Check if this element comes from the PAD array.
+ // The comparison is unsigned ('ult'); NOTE(review): this assumes the
+ // zero-based linear index and SIZE(ARRAY) are never negative.
+ mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
+ ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
+ /*withElseRegion=*/true);
+ // In the 'else' block, return an element from the PAD.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // PAD is dynamically optional, but we can unconditionally access it
+ // in the 'else' block. If we have to start taking elements from it,
+ // then it must be present in a valid program.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
+ hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
+ // Subtract the ARRAY size from the zero-based linear index
+ // to get the zero-based linear index into PAD.
+ mlir::Value padLinearIndex =
+ builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
+ // Fortran RESHAPE reuses PAD's elements, in array element order, as
+ // many times as needed. wrapAround=true presumably makes
+ // delinearizeIndex (not shown here) wrap the index modulo SIZE(PAD)
+ // to implement that reuse; confirm against its definition.
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
+ delinearizeIndex(loc, builder, padExtents, padLinearIndex,
+ /*wrapAround=*/true);
+ mlir::Value padElement =
+ hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
+ builder.create<fir::ResultOp>(loc, padElement);
+ // In the 'then' block, return an element from the ARRAY.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+ // Load the ARRAY element. This happens either in the 'then' block, or
+ // directly in the kernel when PAD is absent. No wrap-around is
+ // requested. NOTE(review): this relies on the linear index being
+ // within SIZE(ARRAY) on this path (guaranteed by the 'ult' check when
+ // PAD is present; otherwise by program validity).
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
+ delinearizeIndex(loc, builder, arrayExtents, linearIndex,
+ /*wrapAround=*/false);
+ mlir::Value arrayElement =
+ hlfir::loadElementAt(loc, builder, array, arrayIndices);
+ if (ifOp) {
+ // Yield the ARRAY element from the 'then' block. Then continue after
+ // the fir.if, whose result is the selected ARRAY/PAD element.
+ builder.create<fir::ResultOp>(loc, arrayElement);
+ builder.setInsertionPointAfter(ifOp);
+ arrayElement = ifOp.getResult(0);
+ }
+ return hlfir::Entity{arrayElement};
+ };
+ // Elements are computed independently of each other, so the elemental
+ // is marked unordered. A polymorphic result uses ARRAY as its mold.
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
+ loc, builder, elementType, resultShape, typeParams, genKernel,
+ /*isUnordered=*/true,
+ /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
+ reshape.getResult().getType());
+ // The reshape's result type was passed explicitly above, so the
+ // elemental must produce exactly that type for replaceOp to be valid.
+ assert(elementalOp.getResult().getType() == reshape.getResult().getType());
+ rewriter.replaceOp(reshape, elementalOp);
+ return mlir::success();
+ }
+ /// Compute zero-based linear index given an array extents
+ /// and one-based indices:
+ /// \p extents: [e0, e1, ..., en]
+ /// \p indices: [i0, i1, ..., in]
+ ///
+ /// linear-index :=
+ /// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
+ static mlir::Value computeLinearIndex(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ mlir::ValueRange extents,
+ mlir::ValueRange indices) {
+ std::size_t rank = extents.size();
+ assert(rank = indices.size());
vzakhari wrote:
Thank you! `==` indeed.
More information about the flang-commits
mailing list