[clang] [CIR] Realign CIR-to-LLVM IR lowering code with incubator (PR #129293)

Andy Kaylor via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 28 13:36:12 PST 2025


https://github.com/andykaylor updated https://github.com/llvm/llvm-project/pull/129293

>From bb41af68d0d0f66c5610c69d6deb8a615d644fe5 Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 10:54:09 -0800
Subject: [PATCH 1/3] [CIR] Replace CIRAttrVisitor with TypeSwitch

We previously discussed having an mlir-tblgen utility to complete the
CIRAttrVisitor implementation with all supported attribute types, but
when I proposed an implementation to do this, a reviewer suggested
using TypeSwitch instead, and I have done that in the incubator.

See https://github.com/llvm/llvm-project/pull/126332

This change brings the TypeSwitch implementation into the upstream repo
to replace the visitor class.
---
 .../clang/CIR/Dialect/IR/CIRAttrVisitor.h     | 52 -------------------
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 31 +++++++----
 2 files changed, 22 insertions(+), 61 deletions(-)
 delete mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
deleted file mode 100644
index bbba89cb7e3fd..0000000000000
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
+++ /dev/null
@@ -1,52 +0,0 @@
-//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the CirAttrVisitor interface.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-
-#include "clang/CIR/Dialect/IR/CIRAttrs.h"
-
-namespace cir {
-
-template <typename ImplClass, typename RetTy> class CirAttrVisitor {
-public:
-  // FIXME: Create a TableGen list to automatically handle new attributes
-  RetTy visit(mlir::Attribute attr) {
-    if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
-      return getImpl().visitCirIntAttr(intAttr);
-    if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
-      return getImpl().visitCirFPAttr(fltAttr);
-    if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
-      return getImpl().visitCirConstPtrAttr(ptrAttr);
-    llvm_unreachable("unhandled attribute type");
-  }
-
-  // If the implementation chooses not to implement a certain visit
-  // method, fall back to the parent.
-  RetTy visitCirIntAttr(cir::IntAttr attr) {
-    return getImpl().visitCirAttr(attr);
-  }
-  RetTy visitCirFPAttr(cir::FPAttr attr) {
-    return getImpl().visitCirAttr(attr);
-  }
-  RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) {
-    return getImpl().visitCirAttr(attr);
-  }
-
-  RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }
-
-  ImplClass &getImpl() { return *static_cast<ImplClass *>(this); }
-};
-
-} // namespace cir
-
-#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index ba7fab2865116..7bf4b5fd27b61 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -24,10 +24,10 @@
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 #include "clang/CIR/MissingFeatures.h"
 #include "clang/CIR/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/TimeProfiler.h"
 
@@ -37,7 +37,7 @@ using namespace llvm;
 namespace cir {
 namespace direct {
 
-class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
+class CIRAttrToValue {
 public:
   CIRAttrToValue(mlir::Operation *parentOp,
                  mlir::ConversionPatternRewriter &rewriter,
@@ -46,19 +46,26 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
 
   mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
 
-  mlir::Value visitCirIntAttr(cir::IntAttr intAttr) {
+  mlir::Value visit(mlir::Attribute attr) {
+    return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
+        .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
+            [&](auto attrT) { return visitCirAttr(attrT); })
+        .Default([&](auto attrT) { return mlir::Value(); });
+  }
+
+  mlir::Value visitCirAttr(cir::IntAttr intAttr) {
     mlir::Location loc = parentOp->getLoc();
     return rewriter.create<mlir::LLVM::ConstantOp>(
         loc, converter->convertType(intAttr.getType()), intAttr.getValue());
   }
 
-  mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) {
+  mlir::Value visitCirAttr(cir::FPAttr fltAttr) {
     mlir::Location loc = parentOp->getLoc();
     return rewriter.create<mlir::LLVM::ConstantOp>(
         loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
   }
 
-  mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
+  mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) {
     mlir::Location loc = parentOp->getLoc();
     if (ptrAttr.isNullValue()) {
       return rewriter.create<mlir::LLVM::ZeroOp>(
@@ -81,8 +88,7 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
 
 // This class handles rewriting initializer attributes for types that do not
 // require region initialization.
-class GlobalInitAttrRewriter
-    : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
+class GlobalInitAttrRewriter {
 public:
   GlobalInitAttrRewriter(mlir::Type type,
                          mlir::ConversionPatternRewriter &rewriter)
@@ -90,10 +96,17 @@ class GlobalInitAttrRewriter
 
   mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
 
-  mlir::Attribute visitCirIntAttr(cir::IntAttr attr) {
+  mlir::Attribute visit(mlir::Attribute attr) {
+    return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
+        .Case<cir::IntAttr, cir::FPAttr>(
+            [&](auto attrT) { return visitCirAttr(attrT); })
+        .Default([&](auto attrT) { return mlir::Attribute(); });
+  }
+
+  mlir::Attribute visitCirAttr(cir::IntAttr attr) {
     return rewriter.getIntegerAttr(llvmType, attr.getValue());
   }
-  mlir::Attribute visitCirFPAttr(cir::FPAttr attr) {
+  mlir::Attribute visitCirAttr(cir::FPAttr attr) {
     return rewriter.getFloatAttr(llvmType, attr.getValue());
   }
 

>From cba15182d96bbdc28009501256fa93951089e70e Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 11:09:40 -0800
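Subject: [PATCH 2/3] Re-align lowering code with incubator implementation

Move the CIRAttrToValue::visitCirAttr overloads out of the class body,
call visit() directly instead of going through the lowerCirAttrAsValue()
and rewriteInitAttr() wrappers, and replace
attrRequiresRegionInitialization() with an explicit dispatch on the
initializer attribute type in CIRToLLVMGlobalOpLowering::matchAndRewrite().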
---
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 109 ++++++++++--------
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h   |   2 -
 2 files changed, 59 insertions(+), 52 deletions(-)

diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 7bf4b5fd27b61..5d083efcdda6f 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -44,8 +44,6 @@ class CIRAttrToValue {
                  const mlir::TypeConverter *converter)
       : parentOp(parentOp), rewriter(rewriter), converter(converter) {}
 
-  mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
-
   mlir::Value visit(mlir::Attribute attr) {
     return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
         .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
@@ -53,32 +51,9 @@ class CIRAttrToValue {
         .Default([&](auto attrT) { return mlir::Value(); });
   }
 
-  mlir::Value visitCirAttr(cir::IntAttr intAttr) {
-    mlir::Location loc = parentOp->getLoc();
-    return rewriter.create<mlir::LLVM::ConstantOp>(
-        loc, converter->convertType(intAttr.getType()), intAttr.getValue());
-  }
-
-  mlir::Value visitCirAttr(cir::FPAttr fltAttr) {
-    mlir::Location loc = parentOp->getLoc();
-    return rewriter.create<mlir::LLVM::ConstantOp>(
-        loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
-  }
-
-  mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) {
-    mlir::Location loc = parentOp->getLoc();
-    if (ptrAttr.isNullValue()) {
-      return rewriter.create<mlir::LLVM::ZeroOp>(
-          loc, converter->convertType(ptrAttr.getType()));
-    }
-    mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
-    mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
-        loc,
-        rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
-        ptrAttr.getValue().getInt());
-    return rewriter.create<mlir::LLVM::IntToPtrOp>(
-        loc, converter->convertType(ptrAttr.getType()), ptrVal);
-  }
+  mlir::Value visitCirAttr(cir::IntAttr intAttr);
+  mlir::Value visitCirAttr(cir::FPAttr fltAttr);
+  mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
 
 private:
   mlir::Operation *parentOp;
@@ -86,6 +61,35 @@ class CIRAttrToValue {
   const mlir::TypeConverter *converter;
 };
 
+/// IntAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
+  mlir::Location loc = parentOp->getLoc();
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, converter->convertType(intAttr.getType()), intAttr.getValue());
+}
+
+/// ConstPtrAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
+  mlir::Location loc = parentOp->getLoc();
+  if (ptrAttr.isNullValue()) {
+    return rewriter.create<mlir::LLVM::ZeroOp>(
+        loc, converter->convertType(ptrAttr.getType()));
+  }
+  mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
+  mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
+      ptrAttr.getValue().getInt());
+  return rewriter.create<mlir::LLVM::IntToPtrOp>(
+      loc, converter->convertType(ptrAttr.getType()), ptrVal);
+}
+
+/// FPAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
+  mlir::Location loc = parentOp->getLoc();
+  return rewriter.create<mlir::LLVM::ConstantOp>(
+      loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+}
+
 // This class handles rewriting initializer attributes for types that do not
 // require region initialization.
 class GlobalInitAttrRewriter {
@@ -94,8 +98,6 @@ class GlobalInitAttrRewriter {
                          mlir::ConversionPatternRewriter &rewriter)
       : llvmType(type), rewriter(rewriter) {}
 
-  mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
-
   mlir::Attribute visit(mlir::Attribute attr) {
     return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
         .Case<cir::IntAttr, cir::FPAttr>(
@@ -137,12 +139,6 @@ struct ConvertCIRToLLVMPass
   StringRef getArgument() const override { return "cir-flat-to-llvm"; }
 };
 
-bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
-    mlir::Attribute attr) const {
-  // There will be more cases added later.
-  return isa<cir::ConstPtrAttr>(attr);
-}
-
 /// Replace CIR global with a region initialized LLVM global and update
 /// insertion point to the end of the initializer block.
 void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -189,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
   // to the appropriate value.
   const mlir::Location loc = op.getLoc();
   setupRegionInitializedLLVMGlobalOp(op, rewriter);
-  CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
-  mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
+  CIRAttrToValue valueConverter(op, rewriter, typeConverter);
+  mlir::Value value = valueConverter.visit(init);
   rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
   return mlir::success();
 }
@@ -201,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
 
   std::optional<mlir::Attribute> init = op.getInitialValue();
 
-  // If we have an initializer and it requires region initialization, handle
-  // that separately
-  if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
-    return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
-  }
-
   // Fetch required values to create LLVM op.
   const mlir::Type cirSymType = op.getSymType();
 
@@ -231,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
   SmallVector<mlir::NamedAttribute> attributes;
 
   if (init.has_value()) {
-    GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
-    init = initRewriter.rewriteInitAttr(init.value());
-    // If initRewriter returned a null attribute, init will have a value but
-    // the value will be null. If that happens, initRewriter didn't handle the
-    // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
-    if (!init.value()) {
+    if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
+      // If a directly equivalent attribute is available, use it.
+      init =
+          llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
+              .Case<cir::FPAttr>([&](cir::FPAttr attr) {
+                return rewriter.getFloatAttr(llvmType, attr.getValue());
+              })
+              .Case<cir::IntAttr>([&](cir::IntAttr attr) {
+                return rewriter.getIntegerAttr(llvmType, attr.getValue());
+              })
+              .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+      // If initRewriter returned a null attribute, init will have a value but
+      // the value will be null.
+      if (!init.value()) {
+        op.emitError() << "unsupported initializer '" << init.value() << "'";
+        return mlir::failure();
+      }
+    } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
+      // TODO(cir): once LLVM's dialect has proper equivalent attributes this
+      // should be updated. For now, we use a custom op to initialize globals
+      // to the appropriate value.
+      return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
+    } else {
+      // We will only get here if new initializer types are added and this
+      // code is not updated to handle them.
       op.emitError() << "unsupported initializer '" << init.value() << "'";
       return mlir::failure();
     }
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index b3366c1fb9337..d1109bb7e1c08 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
                   mlir::ConversionPatternRewriter &rewriter) const override;
 
 private:
-  bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
-
   mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
       cir::GlobalOp op, mlir::Attribute init,
       mlir::ConversionPatternRewriter &rewriter) const;

>From c76fb372cdae335a182f60f8f9c6dea207e8bd1d Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Fri, 28 Feb 2025 13:34:57 -0800
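Subject: [PATCH 3/3] Restore use of GlobalInitAttrRewriter

Use GlobalInitAttrRewriter::visit() for FPAttr and IntAttr global
initializers instead of the inline TypeSwitch added in the previous
patch, and restore the fuller comment explaining what a null result
from the rewriter means.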
---
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp    | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5d083efcdda6f..6f7cae8fa7fa3 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -222,18 +222,12 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
 
   if (init.has_value()) {
     if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
-      // If a directly equivalent attribute is available, use it.
-      init =
-          llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
-              .Case<cir::FPAttr>([&](cir::FPAttr attr) {
-                return rewriter.getFloatAttr(llvmType, attr.getValue());
-              })
-              .Case<cir::IntAttr>([&](cir::IntAttr attr) {
-                return rewriter.getIntegerAttr(llvmType, attr.getValue());
-              })
-              .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+      GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
+      init = initRewriter.visit(init.value());
       // If initRewriter returned a null attribute, init will have a value but
-      // the value will be null.
+      // the value will be null. If that happens, initRewriter didn't handle the
+      // attribute type. It probably needs to be added to
+      // GlobalInitAttrRewriter.
       if (!init.value()) {
         op.emitError() << "unsupported initializer '" << init.value() << "'";
         return mlir::failure();



More information about the cfe-commits mailing list