[flang-commits] [flang] [flang] Lowering FIR memory ops to MemRef dialect (PR #173507)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Sat Dec 27 18:19:32 PST 2025
================
@@ -0,0 +1,225 @@
+//===---- FIRToMemRefTypeConverter.h - FIR type conversion to MemRef ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines `FIRToMemRefTypeConverter`, a helper used by the
+// FIR-to-MemRef conversion pass to convert FIR types (scalars, arrays,
+// descriptors) into MemRef types suitable for the MemRef dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_TRANSFORMS_FIRTOMEMREFTYPECONVERTER_H
+#define FORTRAN_OPTIMIZER_TRANSFORMS_FIRTOMEMREFTYPECONVERTER_H
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+
+class FIRToMemRefTypeConverter : public mlir::TypeConverter {
+private:
+ /// Kind mapping obtained from the module handed to the constructor; used to
+ /// look up logical and character bit sizes.
+ KindMapping kindMapping;
+ /// Value `convertibleType` returns for `mlir::ComplexType`.
+ bool convertComplexTypes = false;
+ /// When true, `convertibleType` skips its reference/sequence unwrapping.
+ bool convertScalarTypesOnly = false;
+
+public:
+ /// Build a converter whose kind mapping is read from `mod`.
+ ///
+ /// Registers:
+ ///  - an identity conversion for any type,
+ ///  - a `fir::LogicalType` -> signless integer conversion whose width comes
+ ///    from the kind mapping,
+ ///  - source and target materializations that bridge a type mismatch with a
+ ///    `fir::ConvertOp` on the first input.
+ explicit FIRToMemRefTypeConverter(mlir::ModuleOp mod)
+     : kindMapping(fir::getKindMapping(mod)) {
+   addConversion([](mlir::Type type) { return type; });
+
+   // NOTE: this callback captures `this` to reach `kindMapping`, so the
+   // registered conversion keeps referring to this very object.
+   addConversion([this](fir::LogicalType type) -> mlir::Type {
+     return mlir::IntegerType::get(
+         type.getContext(), kindMapping.getLogicalBitsize(type.getFKind()));
+   });
+
+   addSourceMaterialization([](mlir::OpBuilder &builder, mlir::Type type,
+                               mlir::ValueRange inputs,
+                               mlir::Location loc) -> mlir::Value {
+     assert(!inputs.empty() && "expected a single input for materialization");
+     // `inputs[0]` may be a block argument, in which case it has no defining
+     // op; setInsertionPointAfterValue copes with both cases instead of
+     // dereferencing a null operation.
+     builder.setInsertionPointAfterValue(inputs[0]);
+     return fir::ConvertOp::create(builder, loc, type, inputs[0]);
+   });
+
+   addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type type,
+                               mlir::ValueRange inputs,
+                               mlir::Location loc) -> mlir::Value {
+     // Mirror the source materialization's precondition before indexing.
+     assert(!inputs.empty() && "expected a single input for materialization");
+     return fir::ConvertOp::create(builder, loc, type, inputs[0]);
+   });
+ }
+
+ /// Choose the answer `convertibleType` gives for `mlir::ComplexType`.
+ void setConvertComplexTypes(bool value) {
+   this->convertComplexTypes = value;
+ }
+
+ /// Switch `convertibleType` between scalar-only mode (true) and the mode
+ /// that also looks through reference and sequence types (false).
+ void setConvertScalarTypesOnly(bool value) {
+   this->convertScalarTypesOnly = value;
+ }
+
+ /// Return true if `ty` can be converted to a MemRef type.
+ ///
+ /// Every enclosing reference, pointer, heap, sequence and box type is
+ /// stripped; the innermost element type is then checked with
+ /// `convertibleType` in scalar-only mode.
+ ///
+ /// The scalar-only flag is forced on only for the duration of the query and
+ /// is put back to the value it had on entry, so a setting made earlier via
+ /// `setConvertScalarTypesOnly` survives this call.
+ bool convertibleMemrefType(mlir::Type ty) {
+   // Peel wrapper types until none of the known wrappers matches.
+   for (;;) {
+     if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty))
+       ty = refTy.getElementType();
+     else if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(ty))
+       ty = pointerTy.getElementType();
+     else if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
+       ty = heapTy.getElementType();
+     else if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
+       ty = seqTy.getElementType();
+     else if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty))
+       ty = boxTy.getElementType();
+     else
+       break;
+   }
+
+   // Save the caller's mode: resetting to a hard-coded `false` would silently
+   // undo a previous setConvertScalarTypesOnly(true).
+   const bool savedScalarTypesOnly = convertScalarTypesOnly;
+   setConvertScalarTypesOnly(true);
+   const bool result = convertibleType(ty);
+   setConvertScalarTypesOnly(savedScalarTypesOnly);
+   return result;
+ }
+
+ /// Report whether `ty` is a sequence type, possibly wrapped in reference,
+ /// pointer or heap types, whose shape contains an extent equal to zero.
+ bool isEmptyArray(mlir::Type ty) const {
+   // Drop reference-like wrappers one layer at a time.
+   for (;;) {
+     if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty)) {
+       ty = refTy.getElementType();
+       continue;
+     }
+     if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(ty)) {
+       ty = pointerTy.getElementType();
+       continue;
+     }
+     if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty)) {
+       ty = heapTy.getElementType();
+       continue;
+     }
+     break;
+   }
+
+   // Anything that is not a sequence type is never an empty array.
+   auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty);
+   if (!seqTy)
+     return false;
+
+   for (int64_t extent : seqTy.getShape())
+     if (extent == 0)
+       return true;
+   return false;
+ }
+
+ /// Decide whether `type` is convertible under the current settings.
+ ///
+ /// Outside scalar-only mode, a reference is convertible when its pointee is
+ /// not a sequence type and is itself convertible, and a sequence is
+ /// convertible when its element type is. For everything else: the only FIR
+ /// type accepted is the logical type; unsigned integers, function types and
+ /// tuple types are rejected; complex types follow `convertComplexTypes`;
+ /// any other type is accepted.
+ bool convertibleType(mlir::Type type) const {
+   const bool lookThroughAggregates = !convertScalarTypesOnly;
+   if (lookThroughAggregates) {
+     if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(type)) {
+       mlir::Type pointee = refTy.getElementType();
+       return !mlir::isa<fir::SequenceType>(pointee) &&
+              convertibleType(pointee);
+     }
+
+     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type))
+       return convertibleType(seqTy.getElementType());
+   }
+
+   // Among FIR dialect types reaching this point, only logical is accepted.
+   if (fir::isa_fir_type(type))
+     return mlir::isa<fir::LogicalType>(type);
+
+   // Complex support is opt-in.
+   if (mlir::isa<mlir::ComplexType>(type))
+     return convertComplexTypes;
+
+   // Unsigned integers, functions and tuples are rejected; the rest passes.
+   const bool rejected = type.isUnsignedInteger() ||
+                         mlir::isa<mlir::FunctionType>(type) ||
+                         mlir::isa<mlir::TupleType>(type);
+   return !rejected;
+ }
+
+ /// Convert a FIR element / aggregate type to a MemRef descriptor type.
+ mlir::MemRefType convertMemrefType(mlir::Type firTy) const {
+ auto convertBaseType = [&](mlir::Type firTy) -> mlir::MemRefType {
+ if (auto charTy = mlir::dyn_cast<fir::CharacterType>(firTy)) {
+ unsigned kind = charTy.getFKind();
+ unsigned bitWidth = kindMapping.getCharacterBitsize(kind);
+ mlir::Type elTy = mlir::IntegerType::get(charTy.getContext(), bitWidth);
+
+ if (charTy.hasConstantLen() && charTy.getLen() == 1) {
+ return mlir::MemRefType::get({}, elTy);
+ } else if (charTy.hasConstantLen()) {
+ int64_t len = charTy.getLen();
+ return mlir::MemRefType::get({len}, elTy);
+ } else {
+ return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, elTy);
+ }
+ }
+
+ if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(firTy)) {
+ auto elTy = seqTy.getElementType();
+ mlir::Type ty = convertType(elTy);
+
+ llvm::ArrayRef<int64_t> firShape = seqTy.getShape();
+ llvm::SmallVector<int64_t> shape;
+ for (auto it = firShape.rbegin(); it != firShape.rend(); ++it)
+ shape.push_back(*it);
+
+ assert(mlir::BaseMemRefType::isValidElementType(ty) &&
+ "got invalid memref element type from array fir type");
+ return mlir::MemRefType::get(shape, ty);
+ }
+
+ mlir::Type ty = convertType(firTy);
+ assert(mlir::BaseMemRefType::isValidElementType(ty) &&
+ "got invalid memref element type from scalar fir type");
+ return mlir::MemRefType::get({}, ty);
+ };
+
+ if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(firTy)) {
+ auto elTy = refTy.getElementType();
+ return convertBaseType(elTy);
+ }
+
+ if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(firTy)) {
+ auto elTy = pointerTy.getElementType();
+ return convertBaseType(elTy);
+ }
+
+ if (auto heapTy = mlir::dyn_cast<fir::HeapType>(firTy)) {
+ auto elTy = heapTy.getElementType();
+ return convertBaseType(elTy);
+ }
----------------
clementval wrote:
You can also simplify the code here.
https://github.com/llvm/llvm-project/pull/173507
More information about the flang-commits
mailing list