[flang-commits] [flang] [flang] Lowering FIR memory ops to MemRef dialect (PR #173507)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Sat Dec 27 18:41:30 PST 2025
================
@@ -0,0 +1,1156 @@
+//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers FIR dialect memory operations to the MemRef dialect.
+// In particular it:
+//
+// - Rewrites `fir.alloca` to `memref.alloca`.
+//
+// - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
+//
+// - Allows FIR and MemRef to coexist by introducing `fir.convert` at
+// memory-use sites. Memory operations (`memref.load`, `memref.store`,
+// `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
+// original FIR-typed values remain available for non-memory uses. For
+// example:
+//
+// %fir_ref = ... : !fir.ref<!fir.array<...>>
+// %memref = fir.convert %fir_ref
+// : !fir.ref<!fir.array<...>> -> memref<...>
+// %val = memref.load %memref[...] : memref<...>
+// fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> ()
+//
+// Here the MemRef-typed value is used for `memref.load`, while the
+// original FIR-typed value is preserved for `fir.call`.
+//
+// - Computes shapes, strides, and indices as needed for slices and shifts
+// and emits `memref.reinterpret_cast` when dynamic layout is required
+// (TODO: use memref.cast instead).
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "fir-to-memref"
+
+using namespace mlir;
+
+namespace fir {
+
+#define GEN_PASS_DEF_FIRTOMEMREF
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+
+static bool isMarshalLike(Operation *op) {
+ if (!op)
+ return false;
+
+ auto convert = dyn_cast<fir::ConvertOp>(op);
+ if (!convert)
+ return false;
+
+ bool resIsMemRef = isa<MemRefType>(convert.getType());
+ bool argIsMemRef = isa<MemRefType>(convert.getValue().getType());
+
+ assert(!(resIsMemRef && argIsMemRef) &&
+ "unexpected fir.convert memref -> memref in isMarshalLike");
+
+ return resIsMemRef || argIsMemRef;
+}
+
+using MemRefInfo = FailureOr<std::pair<Value, SmallVector<Value>>>;
+
+static llvm::cl::opt<bool> enableFIRConvertOptimizations(
+ "enable-fir-convert-opts",
+ llvm::cl::desc("enable emilinating redundant fir.convert in FIR-to-MemRef"),
+ llvm::cl::init(false), llvm::cl::Hidden);
+
+class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
+public:
+ void runOnOperation() override;
+
+private:
+ llvm::SmallSetVector<Operation *, 32> eraseOps;
+
+ DominanceInfo *domInfo = nullptr;
+
+ void rewriteAlloca(fir::AllocaOp, PatternRewriter &,
+ FIRToMemRefTypeConverter &);
+
+ void rewriteLoadOp(fir::LoadOp, PatternRewriter &,
+ FIRToMemRefTypeConverter &);
+
+ void rewriteStoreOp(fir::StoreOp, PatternRewriter &,
+ FIRToMemRefTypeConverter &);
+
+ MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &,
+ Operation *);
+
+ MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp,
+ PatternRewriter &, FIRToMemRefTypeConverter &);
+
+ void replaceFIRMemrefs(Value, Value, PatternRewriter &) const;
+
+ FailureOr<Value> getFIRConvert(Operation *memOp, Operation *memref,
+ PatternRewriter &, FIRToMemRefTypeConverter &);
+
+ FailureOr<SmallVector<Value>> getMemrefIndices(fir::ArrayCoorOp, Operation *,
+ PatternRewriter &, Value,
+ Value) const;
+
+ bool memrefIsOptional(Operation *) const;
+
+ Value canonicalizeIndex(Value, PatternRewriter &) const;
+
+ template <typename OpTy>
+ void getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
+ SmallVector<Value> &shiftVec,
+ SmallVector<Value> &sliceVec) const;
+
+ void populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+ SmallVectorImpl<Value> &shiftVec,
+ fir::ShapeShiftOp shift) const;
+
+ void populateShift(SmallVectorImpl<Value> &vec, fir::ShiftOp shift) const;
+
+ void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
+
+ unsigned getRankFromEmbox(fir::EmboxOp embox) const {
+ auto memrefType = embox.getMemref().getType();
+ Type unwrappedType = fir::unwrapRefType(memrefType);
+ if (auto seqType = dyn_cast<fir::SequenceType>(unwrappedType))
+ return seqType.getDimension();
+ return 0;
+ }
+
+ bool isCompilerGeneratedAlloca(Operation *op) const;
+
+ void copyAttribute(Operation *from, Operation *to,
+ llvm::StringRef name) const;
+
+ Type getBaseType(Type type, bool complexBaseTypes = false) const;
+
+ bool memrefIsDeviceData(Operation *memref) const;
+
+ bool isMarshalLikeOp(Operation *op) const;
+
+ mlir::Attribute findCudaDataAttr(Value val) const;
+};
+
+void FIRToMemRef::populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+ SmallVectorImpl<Value> &shiftVec,
+ fir::ShapeShiftOp shift) const {
+ for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
+ i != endIter;) {
+ shiftVec.push_back(*i++);
+ shapeVec.push_back(*i++);
+ }
+}
+
+bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const {
+ if (!isa<fir::AllocaOp, memref::AllocaOp>(op))
+ llvm_unreachable("expected alloca op");
+
+ return !op->getAttr("bindc_name") && !op->getAttr("uniq_name");
+}
+
+void FIRToMemRef::copyAttribute(Operation *from, Operation *to,
+ llvm::StringRef name) const {
+ if (auto value = from->getAttr(name))
+ to->setAttr(name, value);
+}
+
+Type FIRToMemRef::getBaseType(Type type, bool complexBaseTypes) const {
+ if (fir::isa_fir_type(type)) {
+ type = fir::unwrapAllRefAndSeqType(type);
+ type = fir::unwrapSeqOrBoxedSeqType(type);
+ } else if (auto memrefTy = dyn_cast<MemRefType>(type)) {
+ type = memrefTy.getElementType();
+ }
+
+ if (!complexBaseTypes) {
+ if (auto complexTy = dyn_cast<ComplexType>(type))
+ type = complexTy.getElementType();
+ }
+ return type;
+}
+
+bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const {
+ if (isa<ACC_DATA_ENTRY_OPS>(memref))
+ return true;
+
+ if (auto cudaAttr = cuf::getDataAttr(memref)) {
+ auto attrValue = cudaAttr.getValue();
+ return attrValue == cuf::DataAttribute::Device ||
+ attrValue == cuf::DataAttribute::Managed ||
+ attrValue == cuf::DataAttribute::Constant ||
+ attrValue == cuf::DataAttribute::Shared ||
+ attrValue == cuf::DataAttribute::Unified;
+ }
+ return false;
+}
+
+bool FIRToMemRef::isMarshalLikeOp(Operation *op) const {
+ if (!op)
+ return false;
+
+ auto convert = dyn_cast<fir::ConvertOp>(op);
+ if (convert) {
+ bool resIsMemRef = isa<MemRefType>(convert.getType());
+ bool argIsMemRef = isa<MemRefType>(convert.getValue().getType());
+ assert(!(resIsMemRef && argIsMemRef) &&
+ "unexpected fir.convert memref -> memref in isMarshalLikeOp");
+ }
+
+ auto isaPolymorphicConversion = [](fir::ConvertOp c) {
+ bool retVal{false};
+ if (auto fromBoxTy{dyn_cast<fir::ClassType>(
+ fir::unwrapRefType(c.getValue().getType()))}) {
+ if (auto toBoxTy{
+ dyn_cast<fir::BaseBoxType>(fir::unwrapRefType(c.getType()))}) {
+ auto fromEleTy{fir::unwrapAllRefAndSeqType(fromBoxTy.getEleTy())};
+ auto toEleTy{fir::unwrapAllRefAndSeqType(toBoxTy.getEleTy())};
+ if (fromEleTy != toEleTy)
+ retVal = true;
+ }
+ }
+ return retVal;
+ };
+
+ return convert && !isaPolymorphicConversion(convert) &&
+ (isa<MemRefType>(convert.getType()) ||
+ isa<MemRefType>(convert.getValue().getType()));
+}
+
+mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const {
+ Value currentVal = val;
+ llvm::SmallPtrSet<Operation *, 8> visited;
+
+ while (currentVal) {
+ auto defOp = currentVal.getDefiningOp();
+ if (!defOp || !visited.insert(defOp).second)
+ break;
+
+ if (auto cudaAttr = cuf::getDataAttr(defOp))
+ return cudaAttr;
+
+ if (auto reboxOp = dyn_cast<fir::ReboxOp>(defOp)) {
+ currentVal = reboxOp.getBox();
+ } else if (auto emboxOp = dyn_cast<fir::EmboxOp>(defOp)) {
+ currentVal = emboxOp.getMemref();
+ } else if (auto declareOp = dyn_cast<fir::DeclareOp>(defOp)) {
+ currentVal = declareOp.getMemref();
+ } else {
+ break;
+ }
+ }
+ return nullptr;
+}
+
+void FIRToMemRef::populateShift(SmallVectorImpl<Value> &vec,
+ fir::ShiftOp shift) const {
+ vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
+}
+
+void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
+ fir::ShapeOp shape) const {
+ vec.append(shape.getExtents().begin(), shape.getExtents().end());
+}
+
+template <typename OpTy>
+void FIRToMemRef::getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
+ SmallVector<Value> &shiftVec,
+ SmallVector<Value> &sliceVec) const {
+ if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
+ std::is_same_v<OpTy, fir::ReboxOp> ||
+ std::is_same_v<OpTy, fir::EmboxOp>) {
+ Value shapeVal = op.getShape();
+
+ if (shapeVal) {
+ Operation *shapeValOp = shapeVal.getDefiningOp();
+
+ if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
+ populateShape(shapeVec, shapeOp);
+ } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
+ populateShapeAndShift(shapeVec, shiftVec, shapeShiftOp);
+ } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
+ populateShift(shiftVec, shiftOp);
+ }
+ }
+
+ Value sliceVal = op.getSlice();
+ if (sliceVal) {
+ if (auto sliceOp = sliceVal.getDefiningOp<fir::SliceOp>()) {
+ auto triples = sliceOp.getTriples();
+ sliceVec.append(triples.begin(), triples.end());
+ }
+ }
+ }
+}
+
+void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca,
+ PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ if (!typeConverter.convertibleType(firAlloca.getInType()))
+ return;
+
+ if (typeConverter.isEmptyArray(firAlloca.getType()))
+ return;
+
+ rewriter.setInsertionPointAfter(firAlloca);
+
+ Type type = firAlloca.getType();
+ MemRefType memrefTy = typeConverter.convertMemrefType(type);
+
+ Location loc = firAlloca.getLoc();
+
+ SmallVector<Value> sizes = firAlloca.getOperands();
+ std::reverse(sizes.begin(), sizes.end());
+
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefTy, sizes);
+ copyAttribute(firAlloca, alloca, firAlloca.getBindcNameAttrName());
+ copyAttribute(firAlloca, alloca, firAlloca.getUniqNameAttrName());
+ copyAttribute(firAlloca, alloca, cuf::getDataAttrName());
+
+ auto convert = fir::ConvertOp::create(rewriter, loc, type, alloca);
+
+ rewriter.replaceOp(firAlloca, convert);
+
+ if (isCompilerGeneratedAlloca(alloca)) {
+ for (Operation *userOp : convert->getUsers()) {
+ if (auto declareOp = dyn_cast<fir::DeclareOp>(userOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: removing declare for compiler temp:\n";
+ declareOp->dump());
+ declareOp->replaceAllUsesWith(convert);
+ eraseOps.insert(userOp);
+ }
+ }
+ }
+}
+
+bool FIRToMemRef::memrefIsOptional(Operation *op) const {
+ if (auto declare = dyn_cast<fir::DeclareOp>(op)) {
+ Value operand = declare.getMemref();
+
+ if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+ if (auto func =
+ dyn_cast<func::FuncOp>((blockArg.getOwner())->getParentOp())) {
+ if (func.getArgAttr(blockArg.getArgNumber(),
+ fir::getOptionalAttrName()))
+ return true;
+ }
+ }
+
+ Operation *operandOp = operand.getDefiningOp();
+ if (operandOp && isa<fir::AbsentOp>(operandOp))
+ return true;
+ }
+
+ for (mlir::Value result : op->getResults()) {
+ for (mlir::Operation *userOp : result.getUsers()) {
+ if (isa<fir::IsPresentOp>(userOp))
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static Value castTypeToIndexType(Value originalValue,
+ PatternRewriter &rewriter) {
+ if (originalValue.getType().isIndex())
+ return originalValue;
+
+ Type indexType = rewriter.getIndexType();
+ return arith::IndexCastOp::create(rewriter, originalValue.getLoc(), indexType,
+ originalValue);
+}
+
+FailureOr<SmallVector<Value>>
+FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
+ PatternRewriter &rewriter, Value converted,
+ Value one) const {
+ IndexType indexTy = rewriter.getIndexType();
+ SmallVector<Value> indices;
+ Location loc = arrayCoorOp->getLoc();
+ SmallVector<Value> shiftVec, shapeVec, sliceVec;
+ int rank = arrayCoorOp.getIndices().size();
+ getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec);
+
+ if (auto embox = dyn_cast_or_null<fir::EmboxOp>(memref)) {
+ getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec);
+ rank = getRankFromEmbox(embox);
+ }
+
+ SmallVector<Value> sliceLbs, sliceStrides;
+ for (size_t i = 0; i < sliceVec.size(); i += 3) {
+ sliceLbs.push_back(castTypeToIndexType(sliceVec[i], rewriter));
+ sliceStrides.push_back(castTypeToIndexType(sliceVec[i + 2], rewriter));
+ }
+
+ const bool isShifted = !shiftVec.empty();
+ const bool isSliced = !sliceVec.empty();
+
+ ValueRange idxs = arrayCoorOp.getIndices();
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ SmallVector<bool> filledPositions(rank, false);
+ for (int i = 0; i < rank; ++i) {
+ Value step = isSliced ? sliceStrides[i] : one;
+ Operation *stepOp = step.getDefiningOp();
+ if (stepOp && mlir::isa_and_nonnull<fir::UndefOp>(stepOp)) {
+ Value shift = isShifted ? shiftVec[i] : one;
+ Value sliceLb = isSliced ? sliceLbs[i] : shift;
+ Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift);
+ indices.push_back(offset);
+ filledPositions[i] = true;
+ } else {
+ indices.push_back(zero);
+ }
+ }
+
+ int arrayCoorIdx = 0;
+ for (int i = 0; i < rank; ++i) {
+ if (filledPositions[i])
+ continue;
+
+ assert((unsigned int)arrayCoorIdx < idxs.size() &&
+ "empty dimension should be eliminated\n");
+ Value index = canonicalizeIndex(idxs[arrayCoorIdx], rewriter);
+ Type cTy = index.getType();
+ if (!llvm::isa<IndexType>(cTy)) {
+ assert(cTy.isSignlessInteger() && "expected signless integer type");
+ index = arith::IndexCastOp::create(rewriter, loc, indexTy, index);
+ }
+
+ Value shift = isShifted ? shiftVec[i] : one;
+ Value stride = isSliced ? sliceStrides[i] : one;
+ Value sliceLb = isSliced ? sliceLbs[i] : shift;
+
+ Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ Value indexAdjustment = isSliced ? oneIdx : sliceLb;
+ Value delta = arith::SubIOp::create(rewriter, loc, index, indexAdjustment);
+
+ Value scaled = arith::MulIOp::create(rewriter, loc, delta, stride);
+
+ Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift);
+
+ Value finalIndex = arith::AddIOp::create(rewriter, loc, scaled, offset);
+
+ indices[i] = finalIndex;
+ arrayCoorIdx++;
+ }
+
+ std::reverse(indices.begin(), indices.end());
+
+ return indices;
+}
+
+MemRefInfo
+FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
+ PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ IndexType indexTy = rewriter.getIndexType();
+ Value firMemref = arrayCoorOp.getMemref();
+ if (!typeConverter.convertibleMemrefType(firMemref.getType()))
+ return failure();
+
+ if (typeConverter.isEmptyArray(firMemref.getType()))
+ return failure();
+
+ if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
+ Value elemRef = arrayCoorOp.getResult();
+ rewriter.setInsertionPointAfter(arrayCoorOp);
+ Location loc = arrayCoorOp->getLoc();
+ Type elemMemrefTy = typeConverter.convertMemrefType(elemRef.getType());
+ Value converted =
+ fir::ConvertOp::create(rewriter, loc, elemMemrefTy, elemRef);
+ SmallVector<Value> indices;
+ return std::pair{converted, indices};
+ }
+
+ Operation *memref = firMemref.getDefiningOp();
+
+ FailureOr<Value> converted;
+ if (enableFIRConvertOptimizations && isMarshalLike(memref) &&
+ !fir::isa_fir_type(firMemref.getType())) {
+ converted = firMemref;
+ rewriter.setInsertionPoint(arrayCoorOp);
+ } else {
+ Operation *arrayCoorOperation = arrayCoorOp.getOperation();
+ if (memrefIsOptional(memref)) {
+ rewriter.setInsertionPoint(arrayCoorOp);
+ auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
+ if (ifOp) {
+ Operation *condition = ifOp.getCondition().getDefiningOp();
+ if (condition && isa<fir::IsPresentOp>(condition)) {
+ if (condition->getOperand(0) == firMemref) {
+ if (arrayCoorOperation->getParentRegion() ==
+ &ifOp.getThenRegion()) {
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getThenRegion().front()));
+ } else if (arrayCoorOperation->getParentRegion() ==
+ &ifOp.getElseRegion()) {
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getElseRegion().front()));
+ }
+ }
+ }
+ }
+ }
+
+ rewriter.setInsertionPoint(arrayCoorOp);
+ converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
+ if (failed(converted))
+ return failure();
+
+ rewriter.setInsertionPointAfter(arrayCoorOp);
+ }
+
+ Location loc = arrayCoorOp->getLoc();
+ Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ FailureOr<SmallVector<Value>> failureOrIndices =
+ getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one);
+ if (failed(failureOrIndices))
+ return failure();
+ SmallVector<Value> indices = *failureOrIndices;
+
+ if (converted == firMemref)
+ return std::pair{*converted, indices};
+
+ Value convertedVal = *converted;
+ MemRefType memRefTy = dyn_cast<MemRefType>(convertedVal.getType());
+
+ bool isRebox = firMemref.getDefiningOp<fir::ReboxOp>() != nullptr;
+
+ if (memRefTy.hasStaticShape() && !isRebox)
+ return std::pair{*converted, indices};
+
+ unsigned rank = arrayCoorOp.getIndices().size();
+
+ if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) {
+ rank = getRankFromEmbox(embox);
+ }
+
+ SmallVector<Value> sizes;
+ sizes.reserve(rank);
+ SmallVector<Value> strides;
+ strides.reserve(rank);
+
+ SmallVector<Value> shapeVec, shiftVec, sliceVec;
+ getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec);
+
+ Value box = firMemref;
+ if (!isa<BlockArgument>(firMemref)) {
+ if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>()) {
+ getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec);
+ } else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>()) {
+ getShapeFrom<fir::ReboxOp>(rebox, shapeVec, shiftVec, sliceVec);
+ }
+ }
+
+ if (shapeVec.empty()) {
+ auto boxElementSize =
+ fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box);
+
+ for (unsigned i = 0; i < rank; ++i) {
+ Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1);
+ auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy,
+ indexTy, box, dim);
+
+ Value extent = boxDims->getResult(1);
+ sizes.push_back(castTypeToIndexType(extent, rewriter));
+
+ Value byteStride = boxDims->getResult(2);
+ Value div =
+ arith::DivSIOp::create(rewriter, loc, byteStride, boxElementSize);
+ strides.push_back(castTypeToIndexType(div, rewriter));
+ }
+
+ } else {
+ Value oneIdx =
+ arith::ConstantIndexOp::create(rewriter, arrayCoorOp->getLoc(), 1);
+ for (unsigned i = rank - 1; i > 0; --i) {
+ Value size = shapeVec[i];
+ sizes.push_back(castTypeToIndexType(size, rewriter));
+
+ Value stride = shapeVec[0];
+ for (unsigned j = 1; j <= i - 1; ++j) {
+ stride = arith::MulIOp::create(rewriter, loc, shapeVec[j], stride);
+ }
+ strides.push_back(castTypeToIndexType(stride, rewriter));
+ }
+
+ sizes.push_back(castTypeToIndexType(shapeVec[0], rewriter));
+ strides.push_back(oneIdx);
+ }
+
+ assert(strides.size() == sizes.size() && sizes.size() == rank);
+
+ int64_t dynamicOffset = ShapedType::kDynamic;
+ SmallVector<int64_t> dynamicStrides(rank, ShapedType::kDynamic);
+ auto stridedLayout = StridedLayoutAttr::get(convertedVal.getContext(),
+ dynamicOffset, dynamicStrides);
+
+ SmallVector<int64_t> dynamicShape(rank, ShapedType::kDynamic);
+ memRefTy =
+ MemRefType::get(dynamicShape, memRefTy.getElementType(), stridedLayout);
+
+ Value offset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ auto reinterpret = memref::ReinterpretCastOp::create(
+ rewriter, loc, memRefTy, *converted, offset, sizes, strides);
+
+ Value result = reinterpret->getResult(0);
+ return std::pair{result, indices};
+}
+
+FailureOr<Value>
+FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
+ PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ if (enableFIRConvertOptimizations && !op->hasOneUse() &&
+ !memrefIsOptional(op)) {
+ for (Operation *userOp : op->getUsers()) {
+ if (auto convertOp = dyn_cast<fir::ConvertOp>(userOp)) {
+ Value converted = convertOp.getResult();
+ if (!isa<MemRefType>(converted.getType()))
+ continue;
+
+ if (userOp->getParentOp() == memOp->getParentOp() &&
+ domInfo->dominates(userOp, memOp)) {
+ return converted;
+ }
+ }
+ }
+ }
+
+ assert(op->getNumResults() == 1 && "expecting one result");
+
+ Value basePtr = op->getResult(0);
+
+ MemRefType memrefTy = typeConverter.convertMemrefType(basePtr.getType());
+ Type baseTy = memrefTy.getElementType();
+
+ if (fir::isa_std_type(baseTy) && memrefTy.getRank() == 0) {
+ if (auto convertOp = basePtr.getDefiningOp<fir::ConvertOp>()) {
+ Value input = convertOp.getOperand();
+ if (auto alloca = input.getDefiningOp<memref::AllocaOp>()) {
+ assert(alloca.getType() == memrefTy && "expected same types");
+ if (isCompilerGeneratedAlloca(alloca)) {
+ return alloca.getResult();
+ }
+ }
+ }
+ }
+
+ const Location loc = op->getLoc();
+
+ if (isa<fir::BoxType>(basePtr.getType())) {
+ Operation *baseOp = basePtr.getDefiningOp();
+ auto boxAddrOp = fir::BoxAddrOp::create(rewriter, loc, basePtr);
+
+ if (auto cudaAttr = findCudaDataAttr(basePtr)) {
+ boxAddrOp->setAttr(cuf::getDataAttrName(), cudaAttr);
+ }
+
+ basePtr = boxAddrOp;
+ memrefTy = typeConverter.convertMemrefType(basePtr.getType());
+
+ if (baseOp) {
+ auto sameBaseBoxTypes = [&](Type baseType, Type memrefType) -> bool {
+ Type emboxBaseTy = getBaseType(baseType, true);
+ Type emboxMemrefTy = getBaseType(memrefType, true);
+ return emboxBaseTy == emboxMemrefTy;
+ };
+
+ if (auto embox = dyn_cast_or_null<fir::EmboxOp>(baseOp)) {
+ if (!sameBaseBoxTypes(embox.getType(), embox.getMemref().getType())) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: embox base type and memref type are not "
+ "the same, bailing out of conversion\n");
+ return failure();
+ }
+ if (embox.getSlice() &&
+ embox.getSlice().getDefiningOp<fir::SliceOp>()) {
+ Type originalType = embox.getMemref().getType();
+ basePtr = embox.getMemref();
+
+ if (typeConverter.convertibleMemrefType(originalType)) {
+ auto convertedMemrefTy =
+ typeConverter.convertMemrefType(originalType);
+ memrefTy = convertedMemrefTy;
+ } else {
+ return failure();
+ }
+ }
+ }
+
+ if (auto rebox = dyn_cast<fir::ReboxOp>(baseOp)) {
+ if (!sameBaseBoxTypes(rebox.getType(), rebox.getBox().getType())) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: rebox base type and box type are not the "
+ "same, bailing out of conversion\n");
+ return failure();
+ }
+ Type originalType = rebox.getBox().getType();
+ if (auto boxTy = dyn_cast<fir::BoxType>(originalType)) {
+ originalType = boxTy.getElementType();
+ }
+ if (!typeConverter.convertibleMemrefType(originalType)) {
+ return failure();
+ } else {
+ auto convertedMemrefTy =
+ typeConverter.convertMemrefType(originalType);
+ memrefTy = convertedMemrefTy;
+ }
+ }
+ }
+ }
+
+ auto convert = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr);
+ return convert->getResult(0);
+}
+
+Value FIRToMemRef::canonicalizeIndex(Value index,
+ PatternRewriter &rewriter) const {
+ if (auto blockArg = dyn_cast<BlockArgument>(index))
+ return index;
+
+ Operation *op = index.getDefiningOp();
+
+ if (auto constant = dyn_cast<arith::ConstantIntOp>(op)) {
+ if (!constant.getType().isIndex()) {
+ Value v = arith::ConstantIndexOp::create(rewriter, op->getLoc(),
+ constant.value());
+ return v;
+ }
+ return constant;
+ }
+
+ if (auto extsi = dyn_cast<arith::ExtSIOp>(op)) {
+ Value operand = extsi.getOperand();
+ if (auto indexCast = operand.getDefiningOp<arith::IndexCastOp>()) {
+ Value v = indexCast.getOperand();
+ return v;
+ }
+ return canonicalizeIndex(operand, rewriter);
+ }
+
+ if (auto add = dyn_cast<arith::AddIOp>(op)) {
+ Value lhs = canonicalizeIndex(add.getLhs(), rewriter);
+ Value rhs = canonicalizeIndex(add.getRhs(), rewriter);
+ if (lhs.getType() == rhs.getType()) {
+ return arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs);
+ }
+ }
+ return index;
+}
+
+MemRefInfo FIRToMemRef::getMemRefInfo(Value firMemref,
+ PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter,
+ Operation *memOp) {
+ Operation *memrefOp = firMemref.getDefiningOp();
+ if (!memrefOp) {
+ if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
+ rewriter.setInsertionPoint(memOp);
+ Type memrefTy = typeConverter.convertMemrefType(blockArg.getType());
+ if (auto mt = dyn_cast<MemRefType>(memrefTy)) {
+ if (auto inner = llvm::dyn_cast<MemRefType>(mt.getElementType()))
+ memrefTy = inner;
+ }
+ Value converted = fir::ConvertOp::create(rewriter, blockArg.getLoc(),
+ memrefTy, blockArg);
+ SmallVector<Value> indices;
+ return std::pair{converted, indices};
+ }
+ llvm_unreachable(
+ "FIRToMemRef: expected defining op or block argument for FIR memref");
+ }
+
+ if (auto arrayCoorOp = dyn_cast<fir::ArrayCoorOp>(memrefOp)) {
+ MemRefInfo memrefInfo =
+ convertArrayCoorOp(memOp, arrayCoorOp, rewriter, typeConverter);
+ if (succeeded(memrefInfo)) {
+ for (auto user : memrefOp->getUsers()) {
+ if (!isa<fir::LoadOp, fir::StoreOp>(user)) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "FIRToMemRef: array memref used by unsupported op:\n";
+ firMemref.dump(); user->dump());
+ return memrefInfo;
+ }
+ }
+ eraseOps.insert(memrefOp);
+ }
+ return memrefInfo;
+ }
+
+ rewriter.setInsertionPoint(memOp);
+
+ if (isMarshalLike(memrefOp)) {
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, memrefOp, rewriter, typeConverter);
+ if (failed(converted)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: expected FIR memref in convert, bailing "
+ "out:\n";
+ firMemref.dump());
+ return failure();
+ }
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+
+ if (auto declareOp = dyn_cast<fir::DeclareOp>(memrefOp)) {
+ if (memrefIsOptional(memrefOp)) {
+ rewriter.setInsertionPoint(memOp);
+ auto ifOp = memOp->getParentOfType<scf::IfOp>();
+ if (ifOp) {
+ Operation *condition = ifOp.getCondition().getDefiningOp();
+ if (condition && isa<fir::IsPresentOp>(condition)) {
+ if (condition->getOperand(0) == declareOp) {
+ if (memOp->getParentRegion() == &ifOp.getThenRegion()) {
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getThenRegion().front()));
+ } else if (memOp->getParentRegion() == &ifOp.getElseRegion()) {
+ rewriter.setInsertionPointToStart(
+ &(ifOp.getElseRegion().front()));
+ }
+ }
+ }
+ }
+ }
+
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, declareOp, rewriter, typeConverter);
+ if (failed(converted)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: unable to create convert for scalar "
+ "memref:\n";
+ firMemref.dump());
+ return failure();
+ }
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+
+ if (auto coordinateOp = dyn_cast<fir::CoordinateOp>(memrefOp)) {
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, coordinateOp, rewriter, typeConverter);
+ if (failed(converted)) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "FIRToMemRef: unable to create convert for derived-type "
+ "memref:\n";
+ firMemref.dump());
+ return failure();
+ }
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+
+ if (auto convertOp = dyn_cast<fir::ConvertOp>(memrefOp)) {
+ Type fromTy = convertOp->getOperand(0).getType();
+ Type toTy = firMemref.getType();
+ if (isa<fir::ReferenceType>(fromTy) && isa<fir::ReferenceType>(toTy)) {
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, convertOp, rewriter, typeConverter);
+ if (failed(converted)) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "FIRToMemRef: unable to create convert for conversion "
+ "op:\n";
+ firMemref.dump());
+ return failure();
+ }
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+ }
+
+ if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(memrefOp)) {
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, boxAddrOp, rewriter, typeConverter);
+ if (failed(converted)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: unable to create convert for box_addr "
+ "op:\n";
+ firMemref.dump());
+ return failure();
+ }
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+
+ if (memrefIsDeviceData(memrefOp)) {
+ FailureOr<Value> converted =
+ getFIRConvert(memOp, memrefOp, rewriter, typeConverter);
+ if (failed(converted))
+ return failure();
+ SmallVector<Value> indices;
+ return std::pair{*converted, indices};
+ }
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: unable to create convert for memref value:\n";
+ firMemref.dump());
+
+ return failure();
+}
+
+void FIRToMemRef::replaceFIRMemrefs(Value firMemref, Value converted,
+ PatternRewriter &rewriter) const {
+ Operation *op = firMemref.getDefiningOp();
+ if (op && (isa<fir::ArrayCoorOp>(op) || isMarshalLike(op)))
+ return;
+
+ SmallPtrSet<Operation *, 4> worklist;
+ for (auto user : firMemref.getUsers()) {
+ if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user))
+ continue;
+ if (!domInfo->dominates(converted, user))
+ continue;
+
+ if (!(isa<omp::AtomicCaptureOp>(user->getParentOp()) ||
+ isa<acc::AtomicCaptureOp>(user->getParentOp()))) {
+ worklist.insert(user);
+ }
+ }
+
+ Type ty = firMemref.getType();
+
+ for (auto op : worklist) {
+ rewriter.setInsertionPoint(op);
+ Location loc = op->getLoc();
+ Value replaceConvert = fir::ConvertOp::create(rewriter, loc, ty, converted);
+ op->replaceUsesOfWith(firMemref, replaceConvert);
+ }
+
+ worklist.clear();
+
+ for (auto user : firMemref.getUsers()) {
+ if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user))
+ continue;
+ if (isa<omp::AtomicCaptureOp>(user->getParentOp()) ||
+ isa<acc::AtomicCaptureOp>(user->getParentOp())) {
+ if (domInfo->dominates(converted, user))
+ worklist.insert(user);
+ }
+ }
+
+ if (worklist.empty())
+ return;
+
+ while (!worklist.empty()) {
+ Operation *parentOp = (*worklist.begin())->getParentOp();
+
+ Value replaceConvert;
+ SmallVector<Operation *> erase;
+ for (auto op : worklist) {
+ if (op->getParentOp() != parentOp)
+ continue;
+ if (!replaceConvert) {
+ rewriter.setInsertionPoint(parentOp);
+ replaceConvert =
+ fir::ConvertOp::create(rewriter, op->getLoc(), ty, converted);
+ }
+ op->replaceUsesOfWith(firMemref, replaceConvert);
+ erase.push_back(op);
+ }
+
+ for (auto op : erase)
+ worklist.erase(op);
+ }
+}
+
+void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ Value firMemref = load.getMemref();
+ if (!typeConverter.convertibleType(firMemref.getType()))
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR load:\n";
+ load.dump(); firMemref.dump());
+
+ MemRefInfo memrefInfo =
+ getMemRefInfo(firMemref, rewriter, typeConverter, load.getOperation());
+ if (failed(memrefInfo))
+ return;
+
+ auto originalType = load.getResult().getType();
+ auto [converted, indices] = *memrefInfo;
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: convert for FIR load created successfully:\n";
+ converted.dump());
+
+ rewriter.setInsertionPointAfter(load);
+
+ auto attr = (load.getOperation())->getAttr("tbaa");
+ auto loadOp =
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(load, converted, indices);
+ if (attr)
+ loadOp.getOperation()->setAttr("tbaa", attr);
+
+ LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
+ loadOp.dump(); assert(succeeded(verify(loadOp))));
+
+ if (isa<fir::LogicalType>(originalType)) {
+ SmallVector<Operation *> loadUsers(
+ loadOp.getOperation()->getUsers().begin(),
+ loadOp.getOperation()->getUsers().end());
+ auto logicalVal =
+ fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
+ for (Operation *user : loadUsers) {
+ for (auto &operand : user->getOpOperands()) {
+ if (operand.get() == loadOp)
+ operand.set(logicalVal);
+ }
+ }
+ }
+
+ if (!isa<fir::LogicalType>(originalType))
+ replaceFIRMemrefs(firMemref, converted, rewriter);
+}
+
+void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ Value firMemref = store.getMemref();
+
+ if (!typeConverter.convertibleType(firMemref.getType()))
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR store:\n";
+ store.dump(); firMemref.dump());
+
+ MemRefInfo memrefInfo =
+ getMemRefInfo(firMemref, rewriter, typeConverter, store.getOperation());
+ if (failed(memrefInfo))
+ return;
+
+ auto [converted, indices] = *memrefInfo;
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "FIRToMemRef: convert for FIR store created successfully:\n";
+ converted.dump());
+
+ Value value = store.getValue();
+ rewriter.setInsertionPointAfter(store);
+
+ if (isa<fir::LogicalType>(value.getType())) {
+ auto convertedType = typeConverter.convertType(value.getType());
+ value =
+ fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
+ }
+
+ auto attr = (store.getOperation())->getAttr("tbaa");
+ auto storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ store, value, converted, indices);
+ if (attr)
+ storeOp.getOperation()->setAttr("tbaa", attr);
+
+ LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.store op:\n";
+ storeOp.dump(); assert(succeeded(verify(storeOp))));
+
+ bool isLogicalRef = false;
+ if (auto refTy = llvm::dyn_cast<fir::ReferenceType>(firMemref.getType())) {
+ isLogicalRef = llvm::isa<fir::LogicalType>(refTy.getEleTy());
+ }
+ if (!isLogicalRef)
+ replaceFIRMemrefs(firMemref, converted, rewriter);
+}
+
+void FIRToMemRef::runOnOperation() {
+ LLVM_DEBUG(llvm::dbgs() << "Enter FIRToMemRef()\n");
+
+ auto op = getOperation();
+ auto context = op.getContext();
+ auto mod = op->getParentOfType<ModuleOp>();
+ FIRToMemRefTypeConverter typeConverter(mod);
+
+ typeConverter.setConvertComplexTypes(true);
+
+ PatternRewriter rewriter(context);
+ domInfo = new DominanceInfo(op);
+
+ op.walk([&](fir::AllocaOp alloca) {
+ rewriteAlloca(alloca, rewriter, typeConverter);
+ });
+
+ op.walk([&](Operation *op) {
+ if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {
+ rewriteLoadOp(loadOp, rewriter, typeConverter);
+ } else if (auto storeOp = dyn_cast<fir::StoreOp>(op)) {
+ rewriteStoreOp(storeOp, rewriter, typeConverter);
+ }
+ });
+
+ SmallVector<Operation *> worklist;
+ op->walk([&worklist](Operation *arithOp) {
+ if (llvm::isa<arith::ArithDialect>(arithOp->getDialect())) {
+ if (arithOp->use_empty()) {
+ worklist.push_back(arithOp);
+ }
+ }
+ });
+
+ for (auto eraseOp : eraseOps) {
+ rewriter.eraseOp(eraseOp);
+ }
----------------
clementval wrote:
no braces
https://github.com/llvm/llvm-project/pull/173507
More information about the flang-commits
mailing list