[llvm] [LLVM][Parser] Check invalid overload suffix for intrinsics (PR #108315)

Rahul Joshi via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 11 18:00:01 PDT 2024


https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/108315

None

>From ab5c25591715dde6580d38ec991e64f26031dbf4 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Wed, 11 Sep 2024 17:34:45 -0700
Subject: [PATCH] [LLVM][Parser] Check invalid overload suffix for intrinsics

---
 llvm/include/llvm/AsmParser/LLParser.h        |  4 ++
 llvm/include/llvm/IR/Function.h               |  7 ++-
 llvm/lib/AsmParser/LLParser.cpp               | 55 ++++++++++++++-----
 llvm/lib/IR/Function.cpp                      | 16 ++++--
 .../Assembler/intrinsic-overload-error0.ll    | 12 ++++
 .../Assembler/intrinsic-overload-error1.ll    |  8 +++
 llvm/test/Assembler/intrinsic-overload.ll     | 16 ++++++
 7 files changed, 97 insertions(+), 21 deletions(-)
 create mode 100644 llvm/test/Assembler/intrinsic-overload-error0.ll
 create mode 100644 llvm/test/Assembler/intrinsic-overload-error1.ll
 create mode 100644 llvm/test/Assembler/intrinsic-overload.ll

diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index 9576b935198dd4..1d72d395c9f8b4 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -171,6 +171,10 @@ namespace llvm {
     std::map<unsigned, std::vector<std::pair<GlobalValue::GUID *, LocTy>>>
         ForwardRefTypeIds;
 
+    // Locations of all call instructions whose callee name starts with
+    // "llvm." (for accurate error reporting when resolving overloads).
+    std::map<CallBase *, LocTy> OverloadedIntrinsicCallLocs;
+
     // Map of module ID to path.
     std::map<unsigned, StringRef> ModuleIdMap;
 
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index 866c68d15e4011..3d56c9086ef1a5 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -269,7 +269,12 @@ class LLVM_ABI Function : public GlobalObject, public ilist_node<Function> {
   /// getIntrinsicID() returns Intrinsic::not_intrinsic.
   bool isConstrainedFPIntrinsic() const;
 
-  static Intrinsic::ID lookupIntrinsicID(StringRef Name);
+  static std::pair<Intrinsic::ID, StringRef>
+  lookupIntrinsicIDAndSuffix(StringRef Name);
+
+  static Intrinsic::ID lookupIntrinsicID(StringRef Name) {
+    return lookupIntrinsicIDAndSuffix(Name).first;
+  }
 
   /// Update internal caches that depend on the function name (such as the
   /// intrinsic ID and libcall cache).
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 66edf1e5a29573..e3a062d3303016 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -338,7 +338,7 @@ bool LLParser::validateEndOfModule(bool UpgradeDebugInfo) {
 
   for (const auto &[Name, Info] : make_early_inc_range(ForwardRefVals)) {
     if (StringRef(Name).starts_with("llvm.")) {
-      Intrinsic::ID IID = Function::lookupIntrinsicID(Name);
+      auto [IID, Suffix] = Function::lookupIntrinsicIDAndSuffix(Name);
       if (IID == Intrinsic::not_intrinsic)
         // Don't do anything for unknown intrinsics.
         continue;
@@ -349,18 +349,37 @@ bool LLParser::validateEndOfModule(bool UpgradeDebugInfo) {
       //
       // Additionally, automatically add the required mangling suffix to the
       // intrinsic name. This means that we may replace a single forward
-      // declaration with multiple functions here.
+      // declaration with multiple functions here. If there is a suffix left
+      // over after matching the intrinsic name, it should match the mangling
+      // suffix.
       for (Use &U : make_early_inc_range(Info.first->uses())) {
         auto *CB = dyn_cast<CallBase>(U.getUser());
         if (!CB || !CB->isCallee(&U))
           return error(Info.second, "intrinsic can only be used as callee");
 
+        // Location for error reporting.
+        LocTy Loc;
+        auto II = OverloadedIntrinsicCallLocs.find(CB);
+        if (II != OverloadedIntrinsicCallLocs.end()) {
+          Loc = II->second;
+          OverloadedIntrinsicCallLocs.erase(II);
+        } else {
+          Loc = Info.second;
+        }
+
         SmallVector<Type *> OverloadTys;
         if (!Intrinsic::getIntrinsicSignature(IID, CB->getFunctionType(),
                                               OverloadTys))
-          return error(Info.second, "invalid intrinsic signature");
-
-        U.set(Intrinsic::getDeclaration(M, IID, OverloadTys));
+          return error(Loc, "invalid intrinsic signature");
+        Function *Intrinsic = Intrinsic::getDeclaration(M, IID, OverloadTys);
+        // Note: Suffix is empty for non-overloaded intrinsics, and also for
+        // overloaded intrinsics that are called by their exact base name
+        // (without any mangling suffix). Every name ends with the empty
+        // string, so both of these cases always pass this error check.
+        if (!Intrinsic->getName().ends_with(Suffix))
+          return error(Loc, "invalid intrinsic name, expected @" +
+                                Intrinsic->getName());
+        U.set(Intrinsic);
       }
 
       Info.first->eraseFromParent();
@@ -8079,16 +8098,24 @@ bool LLParser::parseCall(Instruction *&Inst, PerFunctionState &PFS,
     CI->setFastMathFlags(FMF);
   }
 
-  if (CalleeID.Kind == ValID::t_GlobalName &&
-      isOldDbgFormatIntrinsic(CalleeID.StrVal)) {
-    if (SeenNewDbgInfoFormat) {
-      CI->deleteValue();
-      return error(CallLoc, "llvm.dbg intrinsic should not appear in a module "
-                            "using non-intrinsic debug info");
+  if (CalleeID.Kind == ValID::t_GlobalName) {
+    if (StringRef(CalleeID.StrVal).starts_with("llvm.")) {
+      // If this is a call to an intrinsic, remember its location for better
+      // error reporting when overloaded intrinsic resolution fails.
+      OverloadedIntrinsicCallLocs[CI] = CalleeID.Loc;
+    }
+
+    if (isOldDbgFormatIntrinsic(CalleeID.StrVal)) {
+      if (SeenNewDbgInfoFormat) {
+        CI->deleteValue();
+        return error(CallLoc,
+                     "llvm.dbg intrinsic should not appear in a module "
+                     "using non-intrinsic debug info");
+      }
+      if (!SeenOldDbgInfoFormat)
+        M->setNewDbgInfoFormatFlag(false);
+      SeenOldDbgInfoFormat = true;
     }
-    if (!SeenOldDbgInfoFormat)
-      M->setNewDbgInfoFormatFlag(false);
-    SeenOldDbgInfoFormat = true;
   }
   CI->setAttributes(PAL);
   ForwardRefAttrGroups[CI] = FwdRefAttrGrps;
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 82ff4e1bc7f5c5..b160c5fc5be34b 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -974,13 +974,14 @@ static ArrayRef<const char *> findTargetSubtable(StringRef Name) {
   return ArrayRef(&IntrinsicNameTable[1] + TI.Offset, TI.Count);
 }
 
-/// This does the actual lookup of an intrinsic ID which
-/// matches the given function name.
-Intrinsic::ID Function::lookupIntrinsicID(StringRef Name) {
+/// Looks up the intrinsic ID matching the given function name and returns it
+/// together with any name suffix left over after the matched intrinsic name.
+std::pair<Intrinsic::ID, StringRef>
+Function::lookupIntrinsicIDAndSuffix(StringRef Name) {
   ArrayRef<const char *> NameTable = findTargetSubtable(Name);
   int Idx = Intrinsic::lookupLLVMIntrinsicByName(NameTable, Name);
   if (Idx == -1)
-    return Intrinsic::not_intrinsic;
+    return {Intrinsic::not_intrinsic, ""};
 
   // Intrinsic IDs correspond to the location in IntrinsicNameTable, but we have
   // an index into a sub-table.
@@ -992,8 +993,11 @@ Intrinsic::ID Function::lookupIntrinsicID(StringRef Name) {
   const auto MatchSize = strlen(NameTable[Idx]);
   assert(Name.size() >= MatchSize && "Expected either exact or prefix match");
   bool IsExactMatch = Name.size() == MatchSize;
-  return IsExactMatch || Intrinsic::isOverloaded(ID) ? ID
-                                                     : Intrinsic::not_intrinsic;
+  if (IsExactMatch)
+    return {ID, ""};
+  if (Intrinsic::isOverloaded(ID))
+    return {ID, Name.drop_front(MatchSize)};
+  return {Intrinsic::not_intrinsic, ""};
 }
 
 void Function::updateAfterNameChange() {
diff --git a/llvm/test/Assembler/intrinsic-overload-error0.ll b/llvm/test/Assembler/intrinsic-overload-error0.ll
new file mode 100644
index 00000000000000..bb9d773b2bf049
--- /dev/null
+++ b/llvm/test/Assembler/intrinsic-overload-error0.ll
@@ -0,0 +1,12 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+; Check that intrinsic calls with mangling are error checked.
+; Mix good and bad calls to demonstrate that line number tracking is required
+; in the parser to report the correct line number for the second (bad) call.
+define void @foo(float %a, i32 %b) {
+    %c = call i1 @llvm.is.constant.i32(i32 0)
+    ; CHECK: <stdin>:[[@LINE+1]]:18: error: invalid intrinsic name, expected @llvm.is.constant.f32
+    %d = call i1 @llvm.is.constant.i32(float %a)
+    %e = call i1 @llvm.is.constant.i1(i1 false)
+    ret void
+}
diff --git a/llvm/test/Assembler/intrinsic-overload-error1.ll b/llvm/test/Assembler/intrinsic-overload-error1.ll
new file mode 100644
index 00000000000000..851148a32eb4ee
--- /dev/null
+++ b/llvm/test/Assembler/intrinsic-overload-error1.ll
@@ -0,0 +1,8 @@
+; RUN: not llvm-as < %s 2>&1 | FileCheck %s
+
+; Check that intrinsic calls with mangling are error checked.
+define void @foo(float %a, i32 %b) {
+    ; CHECK: <stdin>:[[@LINE+1]]:18: error: invalid intrinsic name, expected @llvm.is.constant.f32
+    %c = call i1 @llvm.is.constant.badsuffix(float %a)
+    ret void
+}
diff --git a/llvm/test/Assembler/intrinsic-overload.ll b/llvm/test/Assembler/intrinsic-overload.ll
new file mode 100644
index 00000000000000..02fcdd1d5c55af
--- /dev/null
+++ b/llvm/test/Assembler/intrinsic-overload.ll
@@ -0,0 +1,16 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+; Check that intrinsic calls without any mangling are converted to the
+; correct mangled forms, and that correctly mangled calls parse correctly.
+define void @foo(float %a, i32 %b) {
+    ; CHECK:  call i1 @llvm.is.constant.f32
+    %c = call i1 @llvm.is.constant(float %a)
+
+    ; CHECK: call i1 @llvm.is.constant.i32
+    %d = call i1 @llvm.is.constant(i32 %b)
+
+    ; CHECK: call i1 @llvm.is.constant.i1
+    %e = call i1 @llvm.is.constant.i1(i1 false)
+
+    ret void
+}



More information about the llvm-commits mailing list