[Mlir-commits] [mlir] [MLIR][LLVM] More on CG Profile: support null function entries (PR #137269)

Bruno Cardoso Lopes llvmlistbot at llvm.org
Mon Apr 28 11:12:22 PDT 2025


https://github.com/bcardosolopes updated https://github.com/llvm/llvm-project/pull/137269

>From d8e04476124e2000aba8a5ebdbf1da785804644b Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 24 Apr 2025 15:49:25 -0700
Subject: [PATCH 1/3] [MLIR][LLVM] More on CG Profile: support null function
 entries

---
 .../include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 15 +++++++++++----
 .../Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp  | 17 ++++++++++++-----
 mlir/lib/Target/LLVMIR/ModuleImport.cpp         | 11 ++++++++---
 mlir/test/Dialect/LLVMIR/module-roundtrip.mlir  |  4 ++--
 mlir/test/Target/LLVMIR/Import/module-flags.ll  |  6 +++---
 mlir/test/Target/LLVMIR/llvmir.mlir             |  4 ++--
 6 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index cfa6ebf3e6775..7a730be974df4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1370,10 +1370,17 @@ def ModuleFlagCGProfileEntryAttr
           ]>]
     ```
   }];
-  let parameters = (ins "FlatSymbolRefAttr":$from,
-                        "FlatSymbolRefAttr":$to,
-                        "uint64_t":$count);
-  let assemblyFormat = "`<` struct(params) `>`";
+  let parameters = (
+    ins OptionalParameter<"std::optional<FlatSymbolRefAttr>">:$from,
+        OptionalParameter<"std::optional<FlatSymbolRefAttr>">:$to,
+        "uint64_t":$count);
+  let assemblyFormat = [{
+    `<`
+      `from` `=` ($from^) : (`null`)? `,`
+      `to` `=` ($to^) : (`null`)? `,`
+      `count` `=` $count
+    `>`
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index e4b2ebee0d9d3..e8e29aa059285 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -281,12 +281,19 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
 
   if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
     for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
-      llvm::Function *fromFn =
-          moduleTranslation.lookupFunction(entry.getFrom().getValue());
-      llvm::Function *toFn =
-          moduleTranslation.lookupFunction(entry.getTo().getValue());
+      llvm::Metadata *fromMetadata =
+          entry.getFrom()
+              ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction(
+                    entry.getFrom()->getValue()))
+              : nullptr;
+      llvm::Metadata *toMetadata =
+          entry.getTo()
+              ? llvm::ValueAsMetadata::get(
+                    moduleTranslation.lookupFunction(entry.getTo()->getValue()))
+              : nullptr;
+
       llvm::Metadata *vals[] = {
-          llvm::ValueAsMetadata::get(fromFn), llvm::ValueAsMetadata::get(toFn),
+          fromMetadata, toMetadata,
           mdb.createConstant(llvm::ConstantInt::get(
               llvm::Type::getInt64Ty(context), entry.getCount()))};
       nodes.push_back(llvm::MDNode::get(context, vals));
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d73c84af48b25..e25de9d18963b 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -521,8 +521,12 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
 
 static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
                                                  llvm::MDTuple *mdTuple) {
-  auto getFunctionSymbol = [&](const llvm::MDOperand &funcMDO) {
-    auto *f = cast<llvm::ValueAsMetadata>(funcMDO);
+  auto getFunctionSymbol =
+      [&](const llvm::MDOperand &funcMDO) -> std::optional<FlatSymbolRefAttr> {
+    auto *f = dyn_cast_or_null<llvm::ValueAsMetadata>(funcMDO);
+    // nullptr is a valid value for the function pointer.
+    if (!f)
+      return std::nullopt;
     auto *llvmFn = cast<llvm::Function>(f->getValue()->stripPointerCasts());
     return FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
   };
@@ -570,7 +574,8 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() {
 
     if (!valAttr) {
       emitWarning(mlirModule.getLoc())
-          << "unsupported module flag value: " << diagMD(val, llvmModule.get());
+          << "unsupported module flag value for key '" << key->getString()
+          << "' : " << diagMD(val, llvmModule.get());
       continue;
     }
 
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index b45c61ff10b74..1450fada8a990 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -9,7 +9,7 @@ module {
                      #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
                      #llvm.mlir.module_flag<append, "CG Profile", [
                        #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-                       #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
+                       #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
                        #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
                     ]>]
 }
@@ -23,6 +23,6 @@ module {
 // CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
 // CHECK-SAME: #llvm.mlir.module_flag<append, "CG Profile", [
 // CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
+// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
 // CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 // CHECK-SAME: ]>]
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index 31ab8afb7ed83..fc963f0f6fa99 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -19,7 +19,7 @@
 ; CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">]
 
 ; // -----
-; expected-warning at -2 {{unsupported module flag value: !4 = !{!"foo", i32 1}}}
+; expected-warning at -2 {{unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1}}}
 !10 = !{ i32 1, !"foo", i32 1 }
 !11 = !{ i32 4, !"bar", i32 37 }
 !12 = !{ i32 2, !"qux", i32 42 }
@@ -36,11 +36,11 @@ declare void @to()
 !20 = !{i32 5, !"CG Profile", !21}
 !21 = distinct !{!22, !23, !24}
 !22 = !{ptr @from, ptr @to, i64 222}
-!23 = !{ptr @from, ptr @from, i64 222}
+!23 = !{ptr @from, null, i64 222}
 !24 = !{ptr @to, ptr @from, i64 222}
 
 ; CHECK: llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
+; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 ; CHECK-SAME: ]>]
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index ffd1071e872ec..fab16d560a9e4 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2840,7 +2840,7 @@ module {
 
 llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
   #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-  #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
+  #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
   #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 ]>]
 llvm.func @from(i32)
@@ -2851,7 +2851,7 @@ llvm.func @to()
 // CHECK: ![[#CGPROF]] = !{i32 5, !"CG Profile", ![[#LIST:]]}
 // CHECK: ![[#LIST]] = distinct !{![[#ENTRY_A:]], ![[#ENTRY_B:]], ![[#ENTRY_C:]]}
 // CHECK: ![[#ENTRY_A]] = !{ptr @from, ptr @to, i64 222}
-// CHECK: ![[#ENTRY_B]] = !{ptr @from, ptr @from, i64 222}
+// CHECK: ![[#ENTRY_B]] = !{ptr @from, null, i64 222}
 // CHECK: ![[#ENTRY_C]] = !{ptr @to, ptr @from, i64 222}
 // CHECK: ![[#DBG]] = !{i32 2, !"Debug Info Version", i32 3}
 

>From 70ba174c7aa31e92ce05d6fb77759d2f159450d2 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 25 Apr 2025 14:50:54 -0700
Subject: [PATCH 2/3] use cast_or_null

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index e25de9d18963b..9ade36524070c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -523,7 +523,7 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
                                                  llvm::MDTuple *mdTuple) {
   auto getFunctionSymbol =
       [&](const llvm::MDOperand &funcMDO) -> std::optional<FlatSymbolRefAttr> {
-    auto *f = dyn_cast_or_null<llvm::ValueAsMetadata>(funcMDO);
+    auto *f = cast_or_null<llvm::ValueAsMetadata>(funcMDO);
     // nullptr is a valid value for the function pointer.
     if (!f)
       return std::nullopt;

>From 95c9aacf8480c914d8e07d96cc531b8c1a225a76 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Mon, 28 Apr 2025 11:04:18 -0700
Subject: [PATCH 3/3] omit null entries

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 13 ++++--------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  4 ++--
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 20 +++++++++++++------
 .../test/Dialect/LLVMIR/module-roundtrip.mlir |  4 ++--
 .../test/Target/LLVMIR/Import/module-flags.ll |  2 +-
 mlir/test/Target/LLVMIR/llvmir.mlir           |  2 +-
 6 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 7a730be974df4..7d6d38ecad897 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1371,16 +1371,11 @@ def ModuleFlagCGProfileEntryAttr
     ```
   }];
   let parameters = (
-    ins OptionalParameter<"std::optional<FlatSymbolRefAttr>">:$from,
-        OptionalParameter<"std::optional<FlatSymbolRefAttr>">:$to,
+    ins OptionalParameter<"FlatSymbolRefAttr">:$from,
+        OptionalParameter<"FlatSymbolRefAttr">:$to,
         "uint64_t":$count);
-  let assemblyFormat = [{
-    `<`
-      `from` `=` ($from^) : (`null`)? `,`
-      `to` `=` ($to^) : (`null`)? `,`
-      `count` `=` $count
-    `>`
-  }];
+
+  let assemblyFormat = "`<` struct(params) `>`";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index e8e29aa059285..35dcde2a33d41 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -284,12 +284,12 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
       llvm::Metadata *fromMetadata =
           entry.getFrom()
               ? llvm::ValueAsMetadata::get(moduleTranslation.lookupFunction(
-                    entry.getFrom()->getValue()))
+                    entry.getFrom().getValue()))
               : nullptr;
       llvm::Metadata *toMetadata =
           entry.getTo()
               ? llvm::ValueAsMetadata::get(
-                    moduleTranslation.lookupFunction(entry.getTo()->getValue()))
+                    moduleTranslation.lookupFunction(entry.getTo().getValue()))
               : nullptr;
 
       llvm::Metadata *vals[] = {
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 9ade36524070c..0a3371c6b154d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -521,14 +521,14 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
 
 static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
                                                  llvm::MDTuple *mdTuple) {
-  auto getFunctionSymbol =
-      [&](const llvm::MDOperand &funcMDO) -> std::optional<FlatSymbolRefAttr> {
+  auto getLLVMFunction =
+      [&](const llvm::MDOperand &funcMDO) -> llvm::Function * {
     auto *f = cast_or_null<llvm::ValueAsMetadata>(funcMDO);
     // nullptr is a valid value for the function pointer.
     if (!f)
-      return std::nullopt;
+      return nullptr;
     auto *llvmFn = cast<llvm::Function>(f->getValue()->stripPointerCasts());
-    return FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
+    return llvmFn;
   };
 
   // Each tuple element becomes one ModuleFlagCGProfileEntryAttr.
@@ -539,9 +539,17 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
     llvm::Constant *llvmConstant =
         cast<llvm::ConstantAsMetadata>(cgEntry->getOperand(2))->getValue();
     uint64_t count = cast<llvm::ConstantInt>(llvmConstant)->getZExtValue();
+    auto *fromFn = getLLVMFunction(cgEntry->getOperand(0));
+    auto *toFn = getLLVMFunction(cgEntry->getOperand(1));
+    // FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
     cgProfile.push_back(ModuleFlagCGProfileEntryAttr::get(
-        mlirModule->getContext(), getFunctionSymbol(cgEntry->getOperand(0)),
-        getFunctionSymbol(cgEntry->getOperand(1)), count));
+        mlirModule->getContext(),
+        fromFn ? FlatSymbolRefAttr::get(mlirModule->getContext(),
+                                        fromFn->getName())
+               : nullptr,
+        toFn ? FlatSymbolRefAttr::get(mlirModule->getContext(), toFn->getName())
+             : nullptr,
+        count));
   }
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index 1450fada8a990..025d9b2287c42 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -9,7 +9,7 @@ module {
                      #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
                      #llvm.mlir.module_flag<append, "CG Profile", [
                        #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-                       #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
+                       #llvm.cgprofile_entry<from = @from, count = 222>,
                        #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
                     ]>]
 }
@@ -23,6 +23,6 @@ module {
 // CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
 // CHECK-SAME: #llvm.mlir.module_flag<append, "CG Profile", [
 // CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
+// CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
 // CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 // CHECK-SAME: ]>]
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index fc963f0f6fa99..09e708de0cc93 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -41,6 +41,6 @@ declare void @to()
 
 ; CHECK: llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
+; CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 ; CHECK-SAME: ]>]
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fab16d560a9e4..80778e4ca3be0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2840,7 +2840,7 @@ module {
 
 llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
   #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
-  #llvm.cgprofile_entry<from = @from, to = null, count = 222>,
+  #llvm.cgprofile_entry<from = @from, count = 222>,
   #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 ]>]
 llvm.func @from(i32)



More information about the Mlir-commits mailing list