[flang-commits] [flang] [flang] Lowering FIR memory ops to MemRef dialect (PR #173507)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Sat Dec 27 18:25:02 PST 2025


================
@@ -0,0 +1,1156 @@
+//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers FIR dialect memory operations to the MemRef dialect.
+// In particular it:
+//
+//  - Rewrites `fir.alloca` to `memref.alloca`.
+//
+//  - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
+//
+//  - Allows FIR and MemRef to coexist by introducing `fir.convert` at
+//    memory-use sites. Memory operations (`memref.load`, `memref.store`,
+//    `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
+//    original FIR-typed values remain available for non-memory uses. For
+//    example:
+//
+//        %fir_ref = ... : !fir.ref<!fir.array<...>>
+//        %memref = fir.convert %fir_ref
+//                    : !fir.ref<!fir.array<...>> -> memref<...>
+//        %val = memref.load %memref[...] : memref<...>
+//        fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> ()
+//
+//    Here the MemRef-typed value is used for `memref.load`, while the
+//    original FIR-typed value is preserved for `fir.call`.
+//
+//  - Computes shapes, strides, and indices as needed for slices and shifts
+//    and emits `memref.reinterpret_cast` when dynamic layout is required
+//    (TODO: use memref.cast instead).
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "fir-to-memref"
+
+using namespace mlir;
+
+namespace fir {
+
+// Pull in the TableGen-generated pass base class
+// (fir::impl::FIRToMemRefBase) that FIRToMemRef derives from below.
+#define GEN_PASS_DEF_FIRTOMEMREF
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+
+/// Returns true when \p op is a fir.convert whose operand or result (but not
+/// both) is MemRef-typed, i.e. a conversion that crosses the FIR/MemRef
+/// boundary. Null and non-convert operations yield false.
+static bool isMarshalLike(Operation *op) {
+  auto convert = dyn_cast_or_null<fir::ConvertOp>(op);
+  if (!convert)
+    return false;
+
+  const bool producesMemRef = isa<MemRefType>(convert.getType());
+  const bool consumesMemRef = isa<MemRefType>(convert.getValue().getType());
+
+  // A memref -> memref fir.convert is not expected in this pass.
+  assert(!(producesMemRef && consumesMemRef) &&
+         "unexpected fir.convert memref -> memref in isMarshalLike");
+
+  return producesMemRef || consumesMemRef;
+}
+
+// Result of resolving a memory access: a Value paired with a list of Values
+// (NOTE(review): presumably the MemRef and its access indices -- confirm
+// against getMemRefInfo), or failure.
+using MemRefInfo = FailureOr<std::pair<Value, SmallVector<Value>>>;
+
+// Hidden, off-by-default switch controlling removal of redundant fir.convert
+// operations in this pass.
+static llvm::cl::opt<bool> enableFIRConvertOptimizations(
+    "enable-fir-convert-opts",
+    llvm::cl::desc("enable eliminating redundant fir.convert in FIR-to-MemRef"),
+    llvm::cl::init(false), llvm::cl::Hidden);
+
+/// Pass lowering FIR memory operations (fir.alloca, fir.load, fir.store) to
+/// the MemRef dialect.
+class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
+public:
+  void runOnOperation() override;
+
+private:
+  /// Operations queued for later erasure, e.g. the fir.declare of a
+  /// compiler-generated alloca (see rewriteAlloca).
+  llvm::SmallSetVector<Operation *, 32> eraseOps;
+
+  /// Dominance information; null until assigned elsewhere in the pass.
+  DominanceInfo *domInfo = nullptr;
+
+  // Rewrites of individual FIR memory operations.
+  void rewriteAlloca(fir::AllocaOp firAlloca, PatternRewriter &rewriter,
+                     FIRToMemRefTypeConverter &typeConverter);
+  void rewriteLoadOp(fir::LoadOp, PatternRewriter &,
+                     FIRToMemRefTypeConverter &);
+  void rewriteStoreOp(fir::StoreOp, PatternRewriter &,
+                      FIRToMemRefTypeConverter &);
+
+  // Resolving the MemRef-side values used by a memory operation.
+  MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &,
+                           Operation *);
+  MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp,
+                                PatternRewriter &, FIRToMemRefTypeConverter &);
+  FailureOr<SmallVector<Value>> getMemrefIndices(fir::ArrayCoorOp, Operation *,
+                                                 PatternRewriter &, Value,
+                                                 Value) const;
+  FailureOr<Value> getFIRConvert(Operation *memOp, Operation *memref,
+                                 PatternRewriter &, FIRToMemRefTypeConverter &);
+  void replaceFIRMemrefs(Value, Value, PatternRewriter &) const;
+  Value canonicalizeIndex(Value, PatternRewriter &) const;
+
+  // Collecting shape / shift / slice operands.
+  template <typename OpTy>
+  void getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
+                    SmallVector<Value> &shiftVec,
+                    SmallVector<Value> &sliceVec) const;
+  void populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+                             SmallVectorImpl<Value> &shiftVec,
+                             fir::ShapeShiftOp shift) const;
+  void populateShift(SmallVectorImpl<Value> &vec, fir::ShiftOp shift) const;
+  void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
+
+  /// Rank of the fir::SequenceType (possibly behind a reference) that
+  /// \p embox boxes, or 0 when the boxed entity is not a sequence.
+  unsigned getRankFromEmbox(fir::EmboxOp embox) const {
+    Type boxedTy = fir::unwrapRefType(embox.getMemref().getType());
+    auto seqTy = dyn_cast<fir::SequenceType>(boxedTy);
+    return seqTy ? seqTy.getDimension() : 0u;
+  }
+
+  // Predicates and small utilities.
+  bool memrefIsOptional(Operation *) const;
+  bool isCompilerGeneratedAlloca(Operation *op) const;
+  bool memrefIsDeviceData(Operation *memref) const;
+  bool isMarshalLikeOp(Operation *op) const;
+  mlir::Attribute findCudaDataAttr(Value val) const;
+  Type getBaseType(Type type, bool complexBaseTypes = false) const;
+  void copyAttribute(Operation *from, Operation *to,
+                     llvm::StringRef name) const;
+};
+
+void FIRToMemRef::populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+                                        SmallVectorImpl<Value> &shiftVec,
+                                        fir::ShapeShiftOp shift) const {
+  // fir.shape_shift operands come in interleaved pairs: the first element of
+  // each pair is appended to shiftVec, the second to shapeVec.
+  auto pairs = shift.getPairs();
+  for (size_t idx = 0, count = pairs.size(); idx < count; idx += 2) {
+    shiftVec.push_back(pairs[idx]);
+    shapeVec.push_back(pairs[idx + 1]);
+  }
+}
+
+bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const {
+  // Only fir.alloca / memref.alloca are valid inputs.
+  if (!isa<fir::AllocaOp, memref::AllocaOp>(op))
+    llvm_unreachable("expected alloca op");
+
+  // An alloca carrying either a bindc_name or a uniq_name attribute is
+  // considered named; only unnamed allocas count as compiler generated.
+  const bool isNamed = op->hasAttr("bindc_name") || op->hasAttr("uniq_name");
+  return !isNamed;
+}
+
+void FIRToMemRef::copyAttribute(Operation *from, Operation *to,
+                                llvm::StringRef name) const {
+  // Copy attribute `name` from `from` to `to` only when `from` has it.
+  Attribute attr = from->getAttr(name);
+  if (!attr)
+    return;
+  to->setAttr(name, attr);
+}
+
+Type FIRToMemRef::getBaseType(Type type, bool complexBaseTypes) const {
+  // Peel the container layers: FIR references/sequences/boxed sequences, or
+  // the element type of a MemRef. Other types are used as-is.
+  Type elemTy = type;
+  if (fir::isa_fir_type(elemTy)) {
+    Type stripped = fir::unwrapAllRefAndSeqType(elemTy);
+    elemTy = fir::unwrapSeqOrBoxedSeqType(stripped);
+  } else if (auto memrefTy = dyn_cast<MemRefType>(elemTy)) {
+    elemTy = memrefTy.getElementType();
+  }
+
+  // Callers that want complex types preserved get the peeled type directly;
+  // otherwise a complex type is reduced to its component element type.
+  if (complexBaseTypes)
+    return elemTy;
+  if (auto complexTy = dyn_cast<ComplexType>(elemTy))
+    return complexTy.getElementType();
+  return elemTy;
+}
+
+bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const {
+  // OpenACC data-entry operations are always treated as device data.
+  if (isa<ACC_DATA_ENTRY_OPS>(memref))
+    return true;
+
+  // Otherwise decide based on the CUDA Fortran data attribute, if present.
+  auto cudaAttr = cuf::getDataAttr(memref);
+  if (!cudaAttr)
+    return false;
+
+  switch (cudaAttr.getValue()) {
+  case cuf::DataAttribute::Device:
+  case cuf::DataAttribute::Managed:
+  case cuf::DataAttribute::Constant:
+  case cuf::DataAttribute::Shared:
+  case cuf::DataAttribute::Unified:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true when \p op is a fir.convert between a MemRef type and a
+/// non-MemRef type (in either direction). Null and non-convert operations
+/// yield false.
+bool FIRToMemRef::isMarshalLikeOp(Operation *op) const {
+  if (!op)
+    return false;
+
+  auto convert = dyn_cast<fir::ConvertOp>(op);
+  if (!convert)
+    return false;
+
+  bool resIsMemRef = isa<MemRefType>(convert.getType());
+  bool argIsMemRef = isa<MemRefType>(convert.getValue().getType());
+  assert(!(resIsMemRef && argIsMemRef) &&
+         "unexpected fir.convert memref -> memref in isMarshalLikeOp");
+
+  // Polymorphic conversions (fir.class -> box with a different element type)
+  // need no separate exclusion: both sides are FIR box types (possibly behind
+  // a reference), never MemRefs, so the check below already rejects them.
+  return resIsMemRef || argIsMemRef;
+}
+
+mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const {
+  // Walk up the def chain through fir.rebox / fir.embox / fir.declare and
+  // return the first CUDA data attribute found. Stop with null at a block
+  // argument, at an already visited op (cycle guard), or at any other op.
+  llvm::SmallPtrSet<Operation *, 8> seen;
+  for (Value cur = val; cur;) {
+    Operation *def = cur.getDefiningOp();
+    if (!def || !seen.insert(def).second)
+      return nullptr;
+
+    if (auto attr = cuf::getDataAttr(def))
+      return attr;
+
+    if (auto rebox = dyn_cast<fir::ReboxOp>(def))
+      cur = rebox.getBox();
+    else if (auto embox = dyn_cast<fir::EmboxOp>(def))
+      cur = embox.getMemref();
+    else if (auto declare = dyn_cast<fir::DeclareOp>(def))
+      cur = declare.getMemref();
+    else
+      return nullptr;
+  }
+  return nullptr;
+}
+
+void FIRToMemRef::populateShift(SmallVectorImpl<Value> &vec,
+                                fir::ShiftOp shift) const {
+  // Append every origin operand of the fir.shift, in order.
+  for (Value origin : shift.getOrigins())
+    vec.push_back(origin);
+}
+
+void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
+                                fir::ShapeOp shape) const {
+  // Append every extent operand of the fir.shape, in order.
+  for (Value extent : shape.getExtents())
+    vec.push_back(extent);
+}
+
+/// Appends the shape extents, shift origins and slice triples referenced by
+/// \p op to \p shapeVec, \p shiftVec and \p sliceVec respectively. Shape and
+/// slice operands that are absent, or not produced by the expected FIR ops,
+/// contribute nothing.
+template <typename OpTy>
+void FIRToMemRef::getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
+                               SmallVector<Value> &shiftVec,
+                               SmallVector<Value> &sliceVec) const {
+  // Extraction is only implemented for these op types; for any other OpTy
+  // the output vectors are left untouched.
+  if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
+                std::is_same_v<OpTy, fir::ReboxOp> ||
+                std::is_same_v<OpTy, fir::EmboxOp>) {
+    // Value::getDefiningOp<T>() is null-safe: a shape operand that is a block
+    // argument (no defining op) is skipped instead of being handed to
+    // dyn_cast, which asserts on null. This mirrors the slice handling below.
+    if (Value shapeVal = op.getShape()) {
+      if (auto shapeOp = shapeVal.getDefiningOp<fir::ShapeOp>())
+        populateShape(shapeVec, shapeOp);
+      else if (auto shapeShiftOp =
+                   shapeVal.getDefiningOp<fir::ShapeShiftOp>())
+        populateShapeAndShift(shapeVec, shiftVec, shapeShiftOp);
+      else if (auto shiftOp = shapeVal.getDefiningOp<fir::ShiftOp>())
+        populateShift(shiftVec, shiftOp);
+    }
+
+    if (Value sliceVal = op.getSlice()) {
+      if (auto sliceOp = sliceVal.getDefiningOp<fir::SliceOp>()) {
+        auto triples = sliceOp.getTriples();
+        sliceVec.append(triples.begin(), triples.end());
+      }
+    }
+  }
+}
+
+void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca,
+                                PatternRewriter &rewriter,
+                                FIRToMemRefTypeConverter &typeConverter) {
+  // Leave the alloca alone when its allocated type is not convertible, or
+  // when it allocates an empty array.
+  if (!typeConverter.convertibleType(firAlloca.getInType()) ||
+      typeConverter.isEmptyArray(firAlloca.getType()))
+    return;
+
+  rewriter.setInsertionPointAfter(firAlloca);
+
+  Type firTy = firAlloca.getType();
+  MemRefType memrefTy = typeConverter.convertMemrefType(firTy);
+  Location loc = firAlloca.getLoc();
+
+  // All fir.alloca operands become memref.alloca dynamic sizes, in reverse
+  // order (NOTE(review): presumably because FIR dimension order is the
+  // reverse of MemRef's -- confirm against convertMemrefType).
+  SmallVector<Value> dynSizes = firAlloca.getOperands();
+  std::reverse(dynSizes.begin(), dynSizes.end());
+
+  auto memAlloca =
+      memref::AllocaOp::create(rewriter, loc, memrefTy, dynSizes);
+  // Carry over the naming and CUDA data attributes.
+  copyAttribute(firAlloca, memAlloca, firAlloca.getBindcNameAttrName());
+  copyAttribute(firAlloca, memAlloca, firAlloca.getUniqNameAttrName());
+  copyAttribute(firAlloca, memAlloca, cuf::getDataAttrName());
+
+  // Existing users keep a FIR-typed value, obtained through a fir.convert
+  // of the new memref.alloca.
+  auto firView = fir::ConvertOp::create(rewriter, loc, firTy, memAlloca);
+  rewriter.replaceOp(firAlloca, firView);
+
+  if (!isCompilerGeneratedAlloca(memAlloca))
+    return;
+
+  // For compiler temporaries, forward every fir.declare of the converted
+  // value to the converted value itself and queue the declare for erasure.
+  for (Operation *user : firView->getUsers()) {
+    auto declareOp = dyn_cast<fir::DeclareOp>(user);
+    if (!declareOp)
+      continue;
+    LLVM_DEBUG(llvm::dbgs()
+                   << "FIRToMemRef: removing declare for compiler temp:\n";
+               declareOp->dump());
+    declareOp->replaceAllUsesWith(firView);
+    eraseOps.insert(user);
+  }
+}
+
+bool FIRToMemRef::memrefIsOptional(Operation *op) const {
+  if (auto declare = dyn_cast<fir::DeclareOp>(op)) {
+    Value operand = declare.getMemref();
+
+    if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+      if (auto func =
+              dyn_cast<func::FuncOp>((blockArg.getOwner())->getParentOp())) {
+        if (func.getArgAttr(blockArg.getArgNumber(),
+                            fir::getOptionalAttrName()))
+          return true;
+      }
+    }
+
+    Operation *operandOp = operand.getDefiningOp();
+    if (operandOp && isa<fir::AbsentOp>(operandOp))
+      return true;
+  }
+
+  for (mlir::Value result : op->getResults()) {
+    for (mlir::Operation *userOp : result.getUsers()) {
+      if (isa<fir::IsPresentOp>(userOp))
+        return true;
+    }
+  }
----------------
clementval wrote:

No braces

https://github.com/llvm/llvm-project/pull/173507


More information about the flang-commits mailing list