[clang] [CIR] Realign CIR-to-LLVM IR lowering code with incubator (PR #129293)
Andy Kaylor via cfe-commits
cfe-commits at lists.llvm.org
Fri Feb 28 13:36:12 PST 2025
https://github.com/andykaylor updated https://github.com/llvm/llvm-project/pull/129293
>From bb41af68d0d0f66c5610c69d6deb8a615d644fe5 Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 10:54:09 -0800
Subject: [PATCH 1/3] [CIR] Replace CIRAttrVisitor with TypeSwitch
We previously discussed having an mlir-tblgen utility to complete the
CIRAttrVisitor implementation with all supported attribute types, but
when I proposed an implementation to do this, a reviewer suggested
using TypeSwitch instead, and I have done that in the incubator.
See https://github.com/llvm/llvm-project/pull/126332
This change brings the TypeSwitch implementation into the upstream repo
to replace the visitor class.
---
.../clang/CIR/Dialect/IR/CIRAttrVisitor.h | 52 -------------------
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 31 +++++++----
2 files changed, 22 insertions(+), 61 deletions(-)
delete mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
deleted file mode 100644
index bbba89cb7e3fd..0000000000000
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
+++ /dev/null
@@ -1,52 +0,0 @@
-//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the CirAttrVisitor interface.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-
-#include "clang/CIR/Dialect/IR/CIRAttrs.h"
-
-namespace cir {
-
-template <typename ImplClass, typename RetTy> class CirAttrVisitor {
-public:
- // FIXME: Create a TableGen list to automatically handle new attributes
- RetTy visit(mlir::Attribute attr) {
- if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
- return getImpl().visitCirIntAttr(intAttr);
- if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
- return getImpl().visitCirFPAttr(fltAttr);
- if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
- return getImpl().visitCirConstPtrAttr(ptrAttr);
- llvm_unreachable("unhandled attribute type");
- }
-
- // If the implementation chooses not to implement a certain visit
- // method, fall back to the parent.
- RetTy visitCirIntAttr(cir::IntAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirFPAttr(cir::FPAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
-
- RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }
-
- ImplClass &getImpl() { return *static_cast<ImplClass *>(this); }
-};
-
-} // namespace cir
-
-#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index ba7fab2865116..7bf4b5fd27b61 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -24,10 +24,10 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/MissingFeatures.h"
#include "clang/CIR/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TimeProfiler.h"
@@ -37,7 +37,7 @@ using namespace llvm;
namespace cir {
namespace direct {
-class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
+class CIRAttrToValue {
public:
CIRAttrToValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
@@ -46,19 +46,26 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
- mlir::Value visitCirIntAttr(cir::IntAttr intAttr) {
+ mlir::Value visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Value(); });
+ }
+
+ mlir::Value visitCirAttr(cir::IntAttr intAttr) {
mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
}
- mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) {
+ mlir::Value visitCirAttr(cir::FPAttr fltAttr) {
mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}
- mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
+ mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) {
mlir::Location loc = parentOp->getLoc();
if (ptrAttr.isNullValue()) {
return rewriter.create<mlir::LLVM::ZeroOp>(
@@ -81,8 +88,7 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
// This class handles rewriting initializer attributes for types that do not
// require region initialization.
-class GlobalInitAttrRewriter
- : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
+class GlobalInitAttrRewriter {
public:
GlobalInitAttrRewriter(mlir::Type type,
mlir::ConversionPatternRewriter &rewriter)
@@ -90,10 +96,17 @@ class GlobalInitAttrRewriter
mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
- mlir::Attribute visitCirIntAttr(cir::IntAttr attr) {
+ mlir::Attribute visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
+ .Case<cir::IntAttr, cir::FPAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Attribute(); });
+ }
+
+ mlir::Attribute visitCirAttr(cir::IntAttr attr) {
return rewriter.getIntegerAttr(llvmType, attr.getValue());
}
- mlir::Attribute visitCirFPAttr(cir::FPAttr attr) {
+ mlir::Attribute visitCirAttr(cir::FPAttr attr) {
return rewriter.getFloatAttr(llvmType, attr.getValue());
}
>From cba15182d96bbdc28009501256fa93951089e70e Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 11:09:40 -0800
Subject: [PATCH 2/3] Re-align lowering code with incubator implementation
---
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 109 ++++++++++--------
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 2 -
2 files changed, 59 insertions(+), 52 deletions(-)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 7bf4b5fd27b61..5d083efcdda6f 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -44,8 +44,6 @@ class CIRAttrToValue {
const mlir::TypeConverter *converter)
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
- mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
-
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
@@ -53,32 +51,9 @@ class CIRAttrToValue {
.Default([&](auto attrT) { return mlir::Value(); });
}
- mlir::Value visitCirAttr(cir::IntAttr intAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(intAttr.getType()), intAttr.getValue());
- }
-
- mlir::Value visitCirAttr(cir::FPAttr fltAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
- }
-
- mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) {
- mlir::Location loc = parentOp->getLoc();
- if (ptrAttr.isNullValue()) {
- return rewriter.create<mlir::LLVM::ZeroOp>(
- loc, converter->convertType(ptrAttr.getType()));
- }
- mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
- mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
- loc,
- rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
- ptrAttr.getValue().getInt());
- return rewriter.create<mlir::LLVM::IntToPtrOp>(
- loc, converter->convertType(ptrAttr.getType()), ptrVal);
- }
+ mlir::Value visitCirAttr(cir::IntAttr intAttr);
+ mlir::Value visitCirAttr(cir::FPAttr fltAttr);
+ mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
private:
mlir::Operation *parentOp;
@@ -86,6 +61,35 @@ class CIRAttrToValue {
const mlir::TypeConverter *converter;
};
+/// IntAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(intAttr.getType()), intAttr.getValue());
+}
+
+/// ConstPtrAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ if (ptrAttr.isNullValue()) {
+ return rewriter.create<mlir::LLVM::ZeroOp>(
+ loc, converter->convertType(ptrAttr.getType()));
+ }
+ mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
+ mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
+ ptrAttr.getValue().getInt());
+ return rewriter.create<mlir::LLVM::IntToPtrOp>(
+ loc, converter->convertType(ptrAttr.getType()), ptrVal);
+}
+
+/// FPAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+}
+
// This class handles rewriting initializer attributes for types that do not
// require region initialization.
class GlobalInitAttrRewriter {
@@ -94,8 +98,6 @@ class GlobalInitAttrRewriter {
mlir::ConversionPatternRewriter &rewriter)
: llvmType(type), rewriter(rewriter) {}
- mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
-
mlir::Attribute visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
.Case<cir::IntAttr, cir::FPAttr>(
@@ -137,12 +139,6 @@ struct ConvertCIRToLLVMPass
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
};
-bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
- mlir::Attribute attr) const {
- // There will be more cases added later.
- return isa<cir::ConstPtrAttr>(attr);
-}
-
/// Replace CIR global with a region initialized LLVM global and update
/// insertion point to the end of the initializer block.
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -189,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
// to the appropriate value.
const mlir::Location loc = op.getLoc();
setupRegionInitializedLLVMGlobalOp(op, rewriter);
- CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
- mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
+ CIRAttrToValue valueConverter(op, rewriter, typeConverter);
+ mlir::Value value = valueConverter.visit(init);
rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
return mlir::success();
}
@@ -201,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
std::optional<mlir::Attribute> init = op.getInitialValue();
- // If we have an initializer and it requires region initialization, handle
- // that separately
- if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
- return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
- }
-
// Fetch required values to create LLVM op.
const mlir::Type cirSymType = op.getSymType();
@@ -231,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
SmallVector<mlir::NamedAttribute> attributes;
if (init.has_value()) {
- GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
- init = initRewriter.rewriteInitAttr(init.value());
- // If initRewriter returned a null attribute, init will have a value but
- // the value will be null. If that happens, initRewriter didn't handle the
- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
- if (!init.value()) {
+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
+ // If a directly equivalent attribute is available, use it.
+ init =
+ llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
+ .Case<cir::FPAttr>([&](cir::FPAttr attr) {
+ return rewriter.getFloatAttr(llvmType, attr.getValue());
+ })
+ .Case<cir::IntAttr>([&](cir::IntAttr attr) {
+ return rewriter.getIntegerAttr(llvmType, attr.getValue());
+ })
+ .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+ // If initRewriter returned a null attribute, init will have a value but
+ // the value will be null.
+ if (!init.value()) {
+ op.emitError() << "unsupported initializer '" << init.value() << "'";
+ return mlir::failure();
+ }
+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
+ // should be updated. For now, we use a custom op to initialize globals
+ // to the appropriate value.
+ return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
+ } else {
+ // We will only get here if new initializer types are added and this
+ // code is not updated to handle them.
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index b3366c1fb9337..d1109bb7e1c08 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
private:
- bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
-
mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const;
>From c76fb372cdae335a182f60f8f9c6dea207e8bd1d Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 13:34:57 -0800
Subject: [PATCH 3/3] Restore use of GlobalInitAttrRewriter
---
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 16 +++++-----------
1 file changed, 5 insertions(+), 11 deletions(-)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5d083efcdda6f..6f7cae8fa7fa3 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -222,18 +222,12 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
if (init.has_value()) {
if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
- // If a directly equivalent attribute is available, use it.
- init =
- llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
- .Case<cir::FPAttr>([&](cir::FPAttr attr) {
- return rewriter.getFloatAttr(llvmType, attr.getValue());
- })
- .Case<cir::IntAttr>([&](cir::IntAttr attr) {
- return rewriter.getIntegerAttr(llvmType, attr.getValue());
- })
- .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+ GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
+ init = initRewriter.visit(init.value());
// If initRewriter returned a null attribute, init will have a value but
- // the value will be null.
+ // the value will be null. If that happens, initRewriter didn't handle the
+ // attribute type. It probably needs to be added to
+ // GlobalInitAttrRewriter.
if (!init.value()) {
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
More information about the cfe-commits
mailing list