[flang-commits] [flang] [flang] Lowering FIR memory ops to MemRef dialect (PR #173507)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Sat Dec 27 18:22:05 PST 2025
================
@@ -0,0 +1,1156 @@
+//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers FIR dialect memory operations to the MemRef dialect.
+// In particular it:
+//
+// - Rewrites `fir.alloca` to `memref.alloca`.
+//
+// - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
+//
+// - Allows FIR and MemRef to coexist by introducing `fir.convert` at
+// memory-use sites. Memory operations (`memref.load`, `memref.store`,
+// `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
+// original FIR-typed values remain available for non-memory uses. For
+// example:
+//
+// %fir_ref = ... : !fir.ref<!fir.array<...>>
+// %memref = fir.convert %fir_ref
+// : !fir.ref<!fir.array<...>> -> memref<...>
+// %val = memref.load %memref[...] : memref<...>
+// fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> ()
+//
+// Here the MemRef-typed value is used for `memref.load`, while the
+// original FIR-typed value is preserved for `fir.call`.
+//
+// - Computes shapes, strides, and indices as needed for slices and shifts
+// and emits `memref.reinterpret_cast` when dynamic layout is required
+// (TODO: use memref.cast instead).
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "fir-to-memref"
+
+using namespace mlir;
+
+namespace fir {
+
+#define GEN_PASS_DEF_FIRTOMEMREF
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+
+/// Returns true if \p op is a `fir.convert` whose operand type or result type
+/// is a MemRefType, i.e. a conversion crossing the FIR <-> MemRef boundary.
+/// A null \p op, or any operation other than `fir.convert`, yields false.
+/// A `fir.convert` from memref to memref is not expected here and is rejected
+/// by an assertion (in builds with assertions enabled).
+static bool isMarshalLike(Operation *op) {
+  if (!op || !isa<fir::ConvertOp>(op))
+    return false;
+
+  auto convertOp = cast<fir::ConvertOp>(op);
+  const bool operandIsMemRef = isa<MemRefType>(convertOp.getValue().getType());
+  const bool resultIsMemRef = isa<MemRefType>(convertOp.getType());
+
+  assert(!(resultIsMemRef && operandIsMemRef) &&
+         "unexpected fir.convert memref -> memref in isMarshalLike");
+
+  return operandIsMemRef || resultIsMemRef;
+}
+
+/// Result type of the memref-resolution helpers (getMemRefInfo and
+/// convertArrayCoorOp): either failure, or a pair of a single Value and a
+/// list of Values.
+/// NOTE(review): presumably the pair is (MemRef-typed base value, indices used
+/// to address it) -- confirm against getMemRefInfo's definition.
+using MemRefInfo = FailureOr<std::pair<Value, SmallVector<Value>>>;
+
+/// Hidden command-line flag `-enable-fir-convert-opts` (default: false).
+/// Per its description, it enables elimination of redundant `fir.convert`
+/// operations in the FIR-to-MemRef pass.
+static llvm::cl::opt<bool> enableFIRConvertOptimizations(
+    "enable-fir-convert-opts",
+    llvm::cl::desc("enable eliminating redundant fir.convert in FIR-to-MemRef"),
+    llvm::cl::init(false), llvm::cl::Hidden);
+
+/// Pass that lowers FIR memory operations (fir.alloca, fir.load, fir.store)
+/// to their MemRef dialect counterparts.
+class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
+public:
+  void runOnOperation() override;
+
+private:
+  /// Insertion-ordered, de-duplicated set of operations. Presumably collected
+  /// during rewriting and erased afterwards -- confirm in runOnOperation.
+  llvm::SmallSetVector<Operation *, 32> eraseOps;
+
+  /// Dominance information used by the rewrites; null until assigned
+  /// (presumably in runOnOperation). NOTE(review): appears non-owning --
+  /// confirm.
+  DominanceInfo *domInfo = nullptr;
+
+  /// Rewrites a fir.alloca into a memref.alloca.
+  void rewriteAlloca(fir::AllocaOp, PatternRewriter &,
+                     FIRToMemRefTypeConverter &);
+
+  /// Rewrites a fir.load into a memref.load.
+  void rewriteLoadOp(fir::LoadOp, PatternRewriter &,
+                     FIRToMemRefTypeConverter &);
+
+  /// Rewrites a fir.store into a memref.store.
+  void rewriteStoreOp(fir::StoreOp, PatternRewriter &,
+                      FIRToMemRefTypeConverter &);
+
+  /// Computes the MemRefInfo for a FIR address value, or failure.
+  /// NOTE(review): the trailing Operation * is presumably the memory op being
+  /// rewritten -- confirm against the definition.
+  MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &,
+                           Operation *);
+
+  /// Computes the MemRefInfo for a fir.array_coor address used by \p memOp,
+  /// or failure.
+  MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp,
+                                PatternRewriter &, FIRToMemRefTypeConverter &);
+
+  void replaceFIRMemrefs(Value, Value, PatternRewriter &) const;
+
+  /// Returns a Value for \p memref as used by \p memOp, or failure.
+  /// NOTE(review): presumably a MemRef-typed fir.convert of the FIR value --
+  /// confirm against the definition.
+  FailureOr<Value> getFIRConvert(Operation *memOp, Operation *memref,
+                                 PatternRewriter &, FIRToMemRefTypeConverter &);
+
+  /// Computes the memref indices for a fir.array_coor, or failure.
+  FailureOr<SmallVector<Value>> getMemrefIndices(fir::ArrayCoorOp, Operation *,
+                                                 PatternRewriter &, Value,
+                                                 Value) const;
+
+  bool memrefIsOptional(Operation *) const;
+
+  Value canonicalizeIndex(Value, PatternRewriter &) const;
+
+  /// Fills \p shapeVec, \p shiftVec and \p sliceVec from \p op.
+  template <typename OpTy>
+  void getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
+                    SmallVector<Value> &shiftVec,
+                    SmallVector<Value> &sliceVec) const;
+
+  /// Splits the operand pairs of a fir.shape_shift: the first element of each
+  /// pair goes to \p shiftVec, the second to \p shapeVec.
+  void populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+                             SmallVectorImpl<Value> &shiftVec,
+                             fir::ShapeShiftOp shift) const;
+
+  void populateShift(SmallVectorImpl<Value> &vec, fir::ShiftOp shift) const;
+
+  void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;
+
+  /// Returns the number of dimensions of the fir.sequence type behind the
+  /// (possibly reference-wrapped) memref operand of \p embox, or 0 when that
+  /// type is not a fir.sequence.
+  unsigned getRankFromEmbox(fir::EmboxOp embox) const {
+    Type memrefType = embox.getMemref().getType();
+    Type unwrappedType = fir::unwrapRefType(memrefType);
+    if (auto seqType = dyn_cast<fir::SequenceType>(unwrappedType))
+      return seqType.getDimension();
+    return 0;
+  }
+
+  /// True when the alloca op carries neither `bindc_name` nor `uniq_name`.
+  bool isCompilerGeneratedAlloca(Operation *op) const;
+
+  /// Copies attribute \p name from \p from to \p to when \p from has it.
+  void copyAttribute(Operation *from, Operation *to,
+                     llvm::StringRef name) const;
+
+  Type getBaseType(Type type, bool complexBaseTypes = false) const;
+
+  bool memrefIsDeviceData(Operation *memref) const;
+
+  bool isMarshalLikeOp(Operation *op) const;
+
+  mlir::Attribute findCudaDataAttr(Value val) const;
+};
+
+/// Walks the operands of \p shift two at a time. The first operand of each
+/// consecutive pair is appended to \p shiftVec and the second to \p shapeVec.
+/// For fir.shape_shift, these pairs are presumably the per-dimension (lower
+/// bound, extent). The operand count is assumed to be even.
+void FIRToMemRef::populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
+                                        SmallVectorImpl<Value> &shiftVec,
+                                        fir::ShapeShiftOp shift) const {
+  OperandRange pairs = shift.getPairs();
+  for (size_t pos = 0, count = pairs.size(); pos < count; pos += 2) {
+    shiftVec.push_back(pairs[pos]);
+    shapeVec.push_back(pairs[pos + 1]);
+  }
+}
+
+/// Returns true when the alloca \p op (a fir.alloca or memref.alloca) has
+/// neither a `bindc_name` nor a `uniq_name` attribute. That presumably means it
+/// was not created for a named source variable.
+///
+/// Passing any other kind of operation is a programming error. It is checked
+/// with an assertion rather than llvm_unreachable, which would turn the
+/// violation into undefined behavior in release builds.
+bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const {
+  // Extra parentheses: the template argument list contains a comma that would
+  // otherwise split the assert macro's argument.
+  assert((isa<fir::AllocaOp, memref::AllocaOp>(op)) && "expected alloca op");
+
+  return !op->hasAttr("bindc_name") && !op->hasAttr("uniq_name");
+}
+
+void FIRToMemRef::copyAttribute(Operation *from, Operation *to,
+ llvm::StringRef name) const {
+ if (auto value = from->getAttr(name))
----------------
clementval wrote:
Spell auto
https://github.com/llvm/llvm-project/pull/173507
More information about the flang-commits
mailing list