[clang] [CIR] Initial implementation of lowering CIR to MLIR (PR #127835)

Aaron Ballman via cfe-commits cfe-commits at lists.llvm.org
Thu Feb 20 06:41:42 PST 2025


================
@@ -0,0 +1,201 @@
+//====- LowerCIRToMLIR.cpp - Lowering from CIR to MLIR --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of CIR operations to MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+#include "clang/CIR/LowerToLLVM.h"
+#include "clang/CIR/MissingFeatures.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/TimeProfiler.h"
+
+using namespace cir;
+using namespace llvm;
+
+namespace cir {
+
+struct ConvertCIRToMLIRPass
+    : public mlir::PassWrapper<ConvertCIRToMLIRPass,
+                               mlir::OperationPass<mlir::ModuleOp>> {
+  void getDependentDialects(mlir::DialectRegistry &registry) const override {
+    registry.insert<mlir::BuiltinDialect, mlir::memref::MemRefDialect>();
+  }
+  void runOnOperation() final;
+
+  StringRef getDescription() const override {
+    return "Convert the CIR dialect module to MLIR standard dialects";
+  }
+
+  StringRef getArgument() const override { return "cir-to-mlir"; }
+};
+
+class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
+public:
+  using OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
+  mlir::LogicalResult
+  matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
+    if (!moduleOp)
+      return mlir::failure();
+
+    mlir::OpBuilder b(moduleOp.getContext());
+
+    const auto cirSymType = op.getSymType();
+    assert(!cir::MissingFeatures::convertTypeForMemory());
+    auto convertedType = getTypeConverter()->convertType(cirSymType);
+    if (!convertedType)
+      return mlir::failure();
+    auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
+    if (!memrefType)
+      memrefType = mlir::MemRefType::get({}, convertedType);
+    // Add an optional alignment to the global memref.
+    assert(!cir::MissingFeatures::opGlobalAlignment());
+    mlir::IntegerAttr memrefAlignment = mlir::IntegerAttr();
+    // Add an optional initial value to the global memref.
+    mlir::Attribute initialValue = mlir::Attribute();
+    std::optional<mlir::Attribute> init = op.getInitialValue();
+    if (init.has_value()) {
+      initialValue =
+          llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
+              .Case<cir::IntAttr>([&](cir::IntAttr attr) {
+                auto rtt = mlir::RankedTensorType::get({}, convertedType);
+                return mlir::DenseIntElementsAttr::get(rtt, attr.getValue());
+              })
+              .Case<cir::FPAttr>([&](cir::FPAttr attr) {
+                auto rtt = mlir::RankedTensorType::get({}, convertedType);
+                return mlir::DenseFPElementsAttr::get(rtt, attr.getValue());
+              })
+              .Default([&](mlir::Attribute attr) {
+                llvm_unreachable("GlobalOp lowering with initial value is not "
+                                 "fully supported yet");
+                return mlir::Attribute();
+              });
+    }
+
+    // Add symbol visibility
+    assert(!cir::MissingFeatures::opGlobalLinkage());
+    std::string symVisibility = "public";
+
+    assert(!cir::MissingFeatures::opGlobalConstant());
+    bool isConstant = false;
+
+    rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
+        op, b.getStringAttr(op.getSymName()),
+        /*sym_visibility=*/b.getStringAttr(symVisibility),
+        /*type=*/memrefType, initialValue,
+        /*constant=*/isConstant,
+        /*alignment=*/memrefAlignment);
+
+    return mlir::success();
+  }
+};
+
+void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
+                                         mlir::TypeConverter &converter) {
+  patterns.add<CIRGlobalOpLowering>(converter, patterns.getContext());
+}
+
+static mlir::TypeConverter prepareTypeConverter() {
+  mlir::TypeConverter converter;
+  converter.addConversion([&](cir::PointerType type) -> mlir::Type {
+    assert(!cir::MissingFeatures::convertTypeForMemory());
+    mlir::Type ty = converter.convertType(type.getPointee());
+    // FIXME: The pointee type might not be converted (e.g. struct)
+    if (!ty)
+      return nullptr;
+    return mlir::MemRefType::get({}, ty);
+  });
+  converter.addConversion(
+      [&](mlir::IntegerType type) -> mlir::Type { return type; });
+  converter.addConversion(
+      [&](mlir::FloatType type) -> mlir::Type { return type; });
+  converter.addConversion([&](cir::VoidType type) -> mlir::Type { return {}; });
+  converter.addConversion([&](cir::IntType type) -> mlir::Type {
+    // arith dialect ops doesn't take signed integer -- drop cir sign here
+    return mlir::IntegerType::get(
+        type.getContext(), type.getWidth(),
+        mlir::IntegerType::SignednessSemantics::Signless);
+  });
+  converter.addConversion([&](cir::SingleType type) -> mlir::Type {
+    return mlir::Float32Type::get(type.getContext());
+  });
+  converter.addConversion([&](cir::DoubleType type) -> mlir::Type {
+    return mlir::Float64Type::get(type.getContext());
+  });
+  converter.addConversion([&](cir::FP80Type type) -> mlir::Type {
+    return mlir::Float80Type::get(type.getContext());
+  });
+  converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type {
+    return converter.convertType(type.getUnderlying());
+  });
+  converter.addConversion([&](cir::FP128Type type) -> mlir::Type {
+    return mlir::Float128Type::get(type.getContext());
+  });
+  converter.addConversion([&](cir::FP16Type type) -> mlir::Type {
+    return mlir::Float16Type::get(type.getContext());
+  });
+  converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
+    return mlir::BFloat16Type::get(type.getContext());
+  });
+
+  return converter;
+}
+
+void ConvertCIRToMLIRPass::runOnOperation() {
+  auto module = getOperation();
+
+  auto converter = prepareTypeConverter();
+
+  mlir::RewritePatternSet patterns(&getContext());
+
+  populateCIRToMLIRConversionPatterns(patterns, converter);
+
+  mlir::ConversionTarget target(getContext());
+  target.addLegalOp<mlir::ModuleOp>();
+  target.addLegalDialect<mlir::memref::MemRefDialect>();
+  target.addIllegalDialect<cir::CIRDialect>();
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<mlir::Pass> createConvertCIRToMLIRPass() {
+  return std::make_unique<ConvertCIRToMLIRPass>();
+}
+
+mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule,
+                                  mlir::MLIRContext &mlirCtx) {
+  llvm::TimeTraceScope scope("Lower CIR To MLIR");
+
+  mlir::PassManager pm(&mlirCtx);
+
+  pm.addPass(createConvertCIRToMLIRPass());
+
+  auto result = !mlir::failed(pm.run(mlirModule));
----------------
AaronBallman wrote:

Please spell out the type (`bool`) instead of using `auto` for `result`.

https://github.com/llvm/llvm-project/pull/127835


More information about the cfe-commits mailing list