[flang-commits] [flang] 044d5b5 - [fir] Add base of the FIR to LLVM IR pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Oct 29 14:21:48 PDT 2021


Author: Valentin Clement
Date: 2021-10-29T23:21:43+02:00
New Revision: 044d5b5dd184ebdf20bcdfc62fa4bc7f1efd047c

URL: https://github.com/llvm/llvm-project/commit/044d5b5dd184ebdf20bcdfc62fa4bc7f1efd047c
DIFF: https://github.com/llvm/llvm-project/commit/044d5b5dd184ebdf20bcdfc62fa4bc7f1efd047c.diff

LOG: [fir] Add base of the FIR to LLVM IR pass

This patch adds the base of the FIR to LLVM IR dialect conversion pass.
It can currently convert the following operations:
 - fir.global
 - fir.has_value
 - fir.address_of
 - fir.undefined

This patch is part of the upstreaming effort from the fir-dev branch. To keep
patches small, it does not cover all FIR operations; several follow-up patches
will convert the remaining operations.

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D112845

Added: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/CodeGen/TypeConverter.h
    flang/test/Fir/convert-to-llvm.fir

Modified: 
    flang/include/flang/Optimizer/CodeGen/CGPasses.td
    flang/lib/Optimizer/CodeGen/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index ffe829644d1aa..103783b7ae668 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -16,6 +16,19 @@
 
 include "mlir/Pass/PassBase.td"
 
+// FIR -> LLVM-IR dialect conversion pass, registered under the command-line
+// name "fir-to-llvm-ir" and anchored on mlir::ModuleOp. The C++ pass class
+// (FIRToLLVMLowering in CodeGen.cpp) derives from the generated
+// fir::FIRToLLVMLoweringBase and is created by fir::createFIRToLLVMPass().
+def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
+  let summary = "Convert FIR dialect to LLVM-IR dialect";
+  let description = [{
+    Convert the FIR dialect to the LLVM-IR dialect of MLIR. This conversion
+    will also convert ops in the standard and FIRCG dialects.
+  }];
+  let constructor = "::fir::createFIRToLLVMPass()";
+  // Dialects declared as dependencies of the pass so that the pass manager
+  // loads them before the pass runs.
+  let dependentDialects = [
+    "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::BuiltinDialect",
+    "mlir::LLVM::LLVMDialect", "mlir::omp::OpenMPDialect"
+  ];
+}
+
 def CodeGenRewrite : Pass<"cg-rewrite"> {
   let summary = "Rewrite some FIR ops into their code-gen forms.";
   let description = [{

diff  --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 6a8d82cb1f671..d42b38c8a61fa 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_flang_library(FIRCodeGen
   CGOps.cpp
+  CodeGen.cpp
   PreCGRewrite.cpp
 
   DEPENDS

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
new file mode 100644
index 0000000000000..867d05a7914ba
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -0,0 +1,205 @@
+//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/CodeGen.h"
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+
+#define DEBUG_TYPE "flang-codegen"
+
+// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
+#include "TypeConverter.h"
+
+namespace {
+/// Shared base for every FIR-to-LLVM rewrite pattern.
+///
+/// It is an mlir::ConvertOpToLLVMPattern specialised for \p FromOp that can
+/// only be built from a fir::LLVMTypeConverter, and it gives subclasses typed
+/// access to that converter.
+template <typename FromOp>
+class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
+public:
+  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
+      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
+
+protected:
+  /// The pattern's type converter, seen as a fir::LLVMTypeConverter. The
+  /// downcast is sound because the only constructor accepts that type.
+  fir::LLVMTypeConverter &lowerTy() const {
+    auto *baseConverter = this->getTypeConverter();
+    return *static_cast<fir::LLVMTypeConverter *>(baseConverter);
+  }
+
+  /// Convert \p ty using the FIR type converter owned by this pattern.
+  mlir::Type convertType(mlir::Type ty) const {
+    fir::LLVMTypeConverter &firConverter = lowerTy();
+    return firConverter.convertType(ty);
+  }
+};
+} // namespace
+
+namespace {
+/// Lower `fir.address_of` to `llvm.mlir.addressof`, addressing the root symbol
+/// referenced by the FIR op and using the converted result type.
+struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Type llvmPtrTy = convertType(addr.getType());
+    auto symbolName = addr.symbol().getRootReference().getValue();
+    rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(addr, llvmPtrTy,
+                                                         symbolName);
+    return mlir::success();
+  }
+};
+
+/// Lower `fir.has_value` (the terminator of a `fir.global` initializer region)
+/// to `llvm.return`, forwarding the already-converted operands unchanged.
+struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto returnedValues = adaptor.getOperands();
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, returnedValues);
+    return mlir::success();
+  }
+};
+
+/// Lower `fir.global` to `llvm.mlir.global`.
+///
+/// The FIR initializer region, if any, is moved into the new LLVM global.
+/// When that region contains a `fir.insert_on_range` that covers the whole
+/// array with a single `arith.constant` (optionally behind one `fir.convert`),
+/// the insert is replaced by one dense `arith.constant` of the array type.
+struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto llvmTy = convertType(global.getType());
+    // A fir.box is already a pointer to a descriptor at the LLVM level; the
+    // global stores the descriptor itself, so use the pointee type.
+    if (global.getType().isa<fir::BoxType>())
+      llvmTy = llvmTy.cast<mlir::LLVM::LLVMPointerType>().getElementType();
+    auto loc = global.getLoc();
+    mlir::Attribute initAttr{};
+    if (global.initVal())
+      initAttr = global.initVal().getValue();
+    auto linkage = convertLinkage(global.linkName());
+    auto isConst = global.constant().hasValue();
+    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, llvmTy, isConst, linkage, global.sym_name(), initAttr);
+    auto &gr = g.getInitializerRegion();
+    rewriter.inlineRegionBefore(global.region(), gr, gr.end());
+    if (!gr.empty()) {
+      // Replace insert_on_range with a constant dense attribute if the
+      // initialization is on the full range.
+      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
+      for (auto insertOp : insertOnRangeOps) {
+        if (!isFullRange(insertOp.coor(), insertOp.getType()))
+          continue;
+        // Only fold when the inserted value is a constant we can splat; a
+        // block argument or any other producer keeps the original insert.
+        auto constant = getInitConstant(insertOp.val());
+        if (!constant)
+          continue;
+        auto seqTy = convertType(insertOp.getType());
+        mlir::Type vecType = mlir::VectorType::get(
+            insertOp.getType().getShape(), constant.getType());
+        auto denseAttr = mlir::DenseElementsAttr::get(
+            vecType.cast<mlir::ShapedType>(), constant.value());
+        rewriter.setInsertionPointAfter(insertOp);
+        rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(insertOp, seqTy,
+                                                             denseAttr);
+      }
+    }
+    rewriter.eraseOp(global);
+    return mlir::success();
+  }
+
+  /// Return the `arith.constant` producing \p val, looking through a single
+  /// `fir.convert`. Returns a null op when \p val is a block argument or is
+  /// not (directly or through `fir.convert`) defined by an `arith.constant`.
+  static mlir::arith::ConstantOp getInitConstant(mlir::Value val) {
+    mlir::Operation *def = val.getDefiningOp();
+    if (auto convertOp = mlir::dyn_cast_or_null<fir::ConvertOp>(def))
+      def = convertOp.value().getDefiningOp();
+    return mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(def);
+  }
+
+  /// Return true if \p indexes, a flat list of (lower, upper) bound pairs with
+  /// one pair per dimension, covers every element of \p seqTy.
+  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+    auto extents = seqTy.getShape();
+    // Require exactly one bound pair per dimension. This also rejects an odd
+    // count, which would otherwise read past the end of `indexes`.
+    if (indexes.size() != 2 * extents.size())
+      return false;
+    for (unsigned i = 0; i < indexes.size(); i += 2) {
+      auto lower = indexes[i].dyn_cast<mlir::IntegerAttr>();
+      auto upper = indexes[i + 1].dyn_cast<mlir::IntegerAttr>();
+      if (!lower || !upper)
+        return false;
+      if (lower.getInt() != 0 || upper.getInt() != extents[i / 2] - 1)
+        return false;
+    }
+    return true;
+  }
+
+  /// Map the optional FIR link name to an LLVM linkage. An absent or
+  /// unrecognized name falls back to external linkage.
+  mlir::LLVM::Linkage
+  convertLinkage(llvm::Optional<llvm::StringRef> optLinkage) const {
+    if (optLinkage.hasValue()) {
+      auto name = optLinkage.getValue();
+      if (name == "internal")
+        return mlir::LLVM::Linkage::Internal;
+      if (name == "linkonce")
+        return mlir::LLVM::Linkage::Linkonce;
+      if (name == "common")
+        return mlir::LLVM::Linkage::Common;
+      if (name == "weak")
+        return mlir::LLVM::Linkage::Weak;
+    }
+    return mlir::LLVM::Linkage::External;
+  }
+};
+
+// convert to LLVM IR dialect `undef`
+struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::UndefOp undef, OpAdaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
+        undef, convertType(undef.getType()));
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Convert FIR dialect to LLVM dialect
+///
+/// This pass lowers FIR dialect operations to the LLVM IR dialect. Residual
+/// Standard and Arithmetic dialect operations are lowered in the same full
+/// conversion by adding MLIR's Std-to-LLVM and Arithmetic-to-LLVM patterns.
+///
+/// This pass is not complete yet. We are upstreaming it in small patches.
+class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
+public:
+  mlir::ModuleOp getModule() { return getOperation(); }
+
+  void runOnOperation() override final {
+    auto mod = getModule();
+    auto *context = mod.getContext();
+    fir::LLVMTypeConverter typeConverter{mod};
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<AddrOfOpConversion, HasValueOpConversion,
+                    GlobalOpConversion, UndefOpConversion>(typeConverter);
+    mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
+                                                            patterns);
+    mlir::ConversionTarget target{*context};
+    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
+
+    // required NOPs for applying a full conversion
+    target.addLegalOp<mlir::ModuleOp>();
+
+    // apply the patterns
+    if (mlir::failed(
+            mlir::applyFullConversion(mod, target, std::move(patterns)))) {
+      // Report at the module's location (not an unknown location) so the
+      // diagnostic points at the input that failed to convert.
+      mlir::emitError(mod.getLoc(),
+                      "error in converting to LLVM-IR dialect\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+/// Factory for the FIR-to-LLVM-IR conversion pass (see CGPasses.td).
+std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
+  std::unique_ptr<mlir::Pass> pass = std::make_unique<FIRToLLVMLowering>();
+  return pass;
+}

diff  --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
new file mode 100644
index 0000000000000..fe63da90b8ecb
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -0,0 +1,85 @@
+//===-- TypeConverter.h -- type conversion ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
+#define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
+
+#include "llvm/Support/Debug.h"
+
+namespace fir {
+
+/// FIR type converter
+/// This converts FIR types to LLVM types (for now). It extends MLIR's
+/// LLVMTypeConverter with conversions for fir.ref and fir.array types.
+class LLVMTypeConverter : public mlir::LLVMTypeConverter {
+public:
+  LLVMTypeConverter(mlir::ModuleOp module)
+      : mlir::LLVMTypeConverter(module.getContext()) {
+    LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
+
+    // Each conversion should return a value of type mlir::Type.
+    addConversion(
+        [&](fir::ReferenceType ref) { return convertPointerLike(ref); });
+    addConversion([&](fir::SequenceType sequence) {
+      return convertSequenceType(sequence);
+    });
+  }
+
+  /// Convert a pointer-like FIR type (any type with getEleTy(), e.g.
+  /// fir.ref<T>) to an LLVM pointer to the converted element type, except
+  /// for the degenerate-array and fir.box cases handled below.
+  /// MLIR types are cheap value handles, so \p ty is taken by value.
+  template <typename A>
+  mlir::Type convertPointerLike(A ty) {
+    mlir::Type eleTy = ty.getEleTy();
+    // A sequence type is a special case. A sequence of runtime size on its
+    // interior dimensions lowers to a memory reference. In that case, we
+    // degenerate the array and do not want the type to become `T**` but
+    // merely `T*`.
+    if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) {
+      if (!seqTy.hasConstantShape() ||
+          characterWithDynamicLen(seqTy.getEleTy())) {
+        if (seqTy.hasConstantInterior())
+          return convertType(seqTy);
+        eleTy = seqTy.getEleTy();
+      }
+    }
+    // fir.ref<fir.box> is a special case because fir.box type is already
+    // a pointer to a Fortran descriptor at the LLVM IR level. This implies
+    // that a fir.ref<fir.box>, that is the address of fir.box is actually
+    // the same as a fir.box at the LLVM level.
+    // The distinction is kept in fir to denote when a descriptor is expected
+    // to be mutable (fir.ref<fir.box>) and when it is not (fir.box).
+    if (eleTy.isa<fir::BoxType>())
+      return convertType(eleTy);
+
+    return mlir::LLVM::LLVMPointerType::get(convertType(eleTy));
+  }
+
+  /// Convert a FIR sequence type.
+  ///   fir.array<c ... :any>  -->  llvm<"[...[c x any]]">
+  /// - Character elements of dynamic length lower to a pointer to the
+  ///   converted element type.
+  /// - Otherwise, one LLVM array is wrapped around the element type for each
+  ///   of the first getConstantRows() extents (shape[0] is the innermost
+  ///   array). getConstantRows() is presumably the count of leading constant
+  ///   extents; confirm against FIRType.
+  /// - If the whole shape is constant, the nested array is returned;
+  ///   otherwise a pointer to it is returned.
+  mlir::Type convertSequenceType(fir::SequenceType seq) {
+    auto baseTy = convertType(seq.getEleTy());
+    if (characterWithDynamicLen(seq.getEleTy()))
+      return mlir::LLVM::LLVMPointerType::get(baseTy);
+    auto shape = seq.getShape();
+    auto constRows = seq.getConstantRows();
+    if (constRows) {
+      decltype(constRows) i = constRows;
+      for (auto e : shape) {
+        baseTy = mlir::LLVM::LLVMArrayType::get(baseTy, e);
+        if (--i == 0)
+          break;
+      }
+      if (seq.hasConstantShape())
+        return baseTy;
+    }
+    return mlir::LLVM::LLVMPointerType::get(baseTy);
+  }
+};
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
new file mode 100644
index 0000000000000..9e1b02590f193
--- /dev/null
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -0,0 +1,83 @@
+// RUN: fir-opt --split-input-file --fir-to-llvm-ir %s | FileCheck %s
+
+// Test simple global LLVM conversion
+
+fir.global @g_i0 : i32 {
+  %1 = arith.constant 0 : i32
+  fir.has_value %1 : i32
+}
+
+// CHECK: llvm.mlir.global external @g_i0() : i32 {
+// CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:   llvm.return %[[C0]] : i32
+// CHECK: }
+
+// -----
+
+fir.global @g_ci5 constant : i32 {
+  %c = arith.constant 5 : i32
+  fir.has_value %c : i32
+}
+
+// CHECK: llvm.mlir.global external constant @g_ci5() : i32 {
+// CHECK:   %[[C5:.*]] = llvm.mlir.constant(5 : i32) : i32
+// CHECK:   llvm.return %[[C5]] : i32
+// CHECK: }
+
+// -----
+
+fir.global internal @i_i515 (515:i32) : i32
+// CHECK: llvm.mlir.global internal @i_i515(515 : i32) : i32
+
+// -----
+
+fir.global common @C_i511 (0:i32) : i32
+// CHECK: llvm.mlir.global common @C_i511(0 : i32) : i32
+
+// -----
+
+fir.global weak @w_i86 (86:i32) : i32
+// CHECK: llvm.mlir.global weak @w_i86(86 : i32) : i32
+
+// -----
+
+fir.global linkonce @w_i86 (86:i32) : i32
+// CHECK: llvm.mlir.global linkonce @w_i86(86 : i32) : i32
+
+// -----
+
+// Test conversion of fir.address_of with fir.global
+
+func @f1() {
+  %0 = fir.address_of(@symbol) : !fir.ref<i64>
+  return
+}
+
+fir.global @symbol : i64 {
+  %0 = arith.constant 1 : i64
+  fir.has_value %0 : i64
+}
+
+// CHECK: %{{.*}} = llvm.mlir.addressof @[[SYMBOL:.*]] : !llvm.ptr<i64>
+
+// CHECK: llvm.mlir.global external @[[SYMBOL]]() : i64 {
+// CHECK:   %{{.*}} = llvm.mlir.constant(1 : i64) : i64
+// CHECK:   llvm.return %{{.*}} : i64
+// CHECK: }
+
+// -----
+
+// Test global with insert_on_range operation covering the full array
+// in initializer region.
+
+fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
+  %c0_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  fir.has_value %2 : !fir.array<32x32xi32>
+}
+
+// CHECK: llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x array<32 x i32>> {
+// CHECK:   %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>>
+// CHECK:   llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>>
+// CHECK: }


        


More information about the flang-commits mailing list