[llvm] llvm.lround: Update verifier to validate support of vector types. (PR #98950)

Sumanth Gundapaneni via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 15 14:12:17 PDT 2024


https://github.com/sgundapa updated https://github.com/llvm/llvm-project/pull/98950

>From 4d308a9520d991b7add6cca14a3b2924785f12c3 Mon Sep 17 00:00:00 2001
From: Sumanth Gundapaneni <sumanth.gundapaneni at amd.com>
Date: Fri, 12 Jul 2024 14:10:48 -0500
Subject: [PATCH] llvm.lround: Update verifier to validate support of vector
 types.

(cherry picked from commit 12ed1a477c451584f840978af1f34ba0c98d5215)
---
 llvm/lib/CodeGen/MachineVerifier.cpp         | 15 ++++-
 llvm/lib/IR/Verifier.cpp                     | 17 +++++-
 llvm/test/MachineVerifier/test_g_llround.mir | 16 +++--
 llvm/test/MachineVerifier/test_g_lround.mir  |  8 ++-
 llvm/unittests/IR/IntrinsicsTest.cpp         | 64 ++++++++++++++++++++
 5 files changed, 109 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index 0a5b8bdbc9371..3ae2f7951cefa 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -2032,7 +2032,20 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
   }
   case TargetOpcode::G_LLROUND:
   case TargetOpcode::G_LROUND: {
-    verifyAllRegOpsScalar(*MI, *MRI);
+    LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+    LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
+    if (!DstTy.isValid() || !SrcTy.isValid())
+      break;
+    if (SrcTy.isPointer() || DstTy.isPointer()) {
+      std::string Op = SrcTy.isPointer() ? "Source" : "Destination";
+      report(Twine(Op, " operand must not be a pointer type"), MI);
+    } else if (SrcTy.isScalar()) {
+      verifyAllRegOpsScalar(*MI, *MRI);
+      break;
+    } else if (SrcTy.isVector()) {
+      verifyVectorElementMatch(SrcTy, DstTy, MI);
+      break;
+    }
     break;
   }
   case TargetOpcode::G_IS_FPCLASS: {
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 75a53c1c99734..76fe759fc1f63 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5963,8 +5963,21 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
   case Intrinsic::llround: {
     Type *ValTy = Call.getArgOperand(0)->getType();
     Type *ResultTy = Call.getType();
-    Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
-          "Intrinsic does not support vectors", &Call);
+    Check(
+        ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
+        "llvm.lround, llvm.llround: argument must be floating-point or vector "
+        "of floating-points, and result must be integer or vector of integers",
+        &Call);
+    Check(
+        ValTy->isVectorTy() == ResultTy->isVectorTy(),
+        "llvm.lround, llvm.llround: argument and result disagree on vector use",
+        &Call);
+    if (ValTy->isVectorTy()) {
+      Check(cast<VectorType>(ValTy)->getElementCount() ==
+                cast<VectorType>(ResultTy)->getElementCount(),
+            "llvm.lround, llvm.llround: argument must be same length as result",
+            &Call);
+    }
     break;
   }
   case Intrinsic::bswap: {
diff --git a/llvm/test/MachineVerifier/test_g_llround.mir b/llvm/test/MachineVerifier/test_g_llround.mir
index 9a0f4a75acaf4..e69499b1150c1 100644
--- a/llvm/test/MachineVerifier/test_g_llround.mir
+++ b/llvm/test/MachineVerifier/test_g_llround.mir
@@ -14,10 +14,14 @@ body:             |
     %ptr:_(p0) = COPY $x0
     %vector:_(<2 x s64>) = COPY $q0
 
-    ; CHECK: Bad machine code: All register operands must have scalar types
-    ; CHECK: instruction: %no_ptrs:_(s64) = G_LROUND %ptr:_(p0)
-    %no_ptrs:_(s64) = G_LROUND %ptr:_(p0)
+    ; CHECK: Bad machine code: Source operand must not be a pointer type
+    ; CHECK: instruction: %no_ptrs:_(s32) = G_LLROUND %ptr:_(p0)
+    %no_ptrs:_(s32) = G_LLROUND %ptr:_(p0)
 
-    ; CHECK: Bad machine code: All register operands must have scalar types
-    ; CHECK: instruction: %no_vectors:_(s64) = G_LROUND %vector:_(<2 x s64>)
-    %no_vectors:_(s64) = G_LROUND %vector:_(<2 x s64>)
+    ; CHECK: Bad machine code: operand types must be all-vector or all-scalar
+    ; CHECK: instruction: %no_vectors:_(s32) = G_LLROUND %vector:_(<2 x s64>)
+    %no_vectors:_(s32) = G_LLROUND %vector:_(<2 x s64>)
+
+    ; CHECK: Bad machine code: operand types must preserve number of vector elements
+    ; CHECK: instruction: %inv_vectors:_(<3 x s32>) = G_LLROUND %vector:_(<2 x s64>)
+    %inv_vectors:_(<3 x s32>) = G_LLROUND %vector:_(<2 x s64>)
diff --git a/llvm/test/MachineVerifier/test_g_lround.mir b/llvm/test/MachineVerifier/test_g_lround.mir
index 69d5d4967de30..56f06f00049e7 100644
--- a/llvm/test/MachineVerifier/test_g_lround.mir
+++ b/llvm/test/MachineVerifier/test_g_lround.mir
@@ -14,10 +14,14 @@ body:             |
     %ptr:_(p0) = COPY $x0
     %vector:_(<2 x s64>) = COPY $q0
 
-    ; CHECK: Bad machine code: All register operands must have scalar types
+    ; CHECK: Bad machine code: Source operand must not be a pointer type
     ; CHECK: instruction: %no_ptrs:_(s32) = G_LROUND %ptr:_(p0)
     %no_ptrs:_(s32) = G_LROUND %ptr:_(p0)
 
-    ; CHECK: Bad machine code: All register operands must have scalar types
+    ; CHECK: Bad machine code: operand types must be all-vector or all-scalar
     ; CHECK: instruction: %no_vectors:_(s32) = G_LROUND %vector:_(<2 x s64>)
     %no_vectors:_(s32) = G_LROUND %vector:_(<2 x s64>)
+
+    ; CHECK: Bad machine code: operand types must preserve number of vector elements
+    ; CHECK: instruction: %inv_vectors:_(<3 x s32>) = G_LROUND %vector:_(<2 x s64>)
+    %inv_vectors:_(<3 x s32>) = G_LROUND %vector:_(<2 x s64>)
diff --git a/llvm/unittests/IR/IntrinsicsTest.cpp b/llvm/unittests/IR/IntrinsicsTest.cpp
index 6f9e724c40326..14badaa0de980 100644
--- a/llvm/unittests/IR/IntrinsicsTest.cpp
+++ b/llvm/unittests/IR/IntrinsicsTest.cpp
@@ -12,6 +12,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -106,4 +107,67 @@ TEST_F(IntrinsicsTest, InstrProfInheritance) {
     EXPECT_TRUE(Checker(*Intr));
   }
 }
+// Checks how the IR verifier treats llvm.lround/llvm.llround operand types.
+TEST(IntrinsicVerifierTest, LRound) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = std::make_unique<Module>("M", C);
+  IRBuilder<> Builder(C);
+
+  using TypePair = std::pair<Type *, Type *>; // {result, argument}
+  Type *Int32Ty = Type::getInt32Ty(C);
+  Type *Int64Ty = Type::getInt64Ty(C);
+  Type *HalfTy = Type::getHalfTy(C);
+  Type *FltTy = Type::getFloatTy(C);
+  Type *DblTy = Type::getDoubleTy(C);
+  auto Vec2xTy = [&](Type *ElemTy) {
+    return VectorType::get(ElemTy, ElementCount::getFixed(2));
+  };
+  Type *Vec2xInt32Ty = Vec2xTy(Int32Ty);
+  Type *Vec2xInt64Ty = Vec2xTy(Int64Ty);
+  Type *Vec2xFltTy = Vec2xTy(FltTy);
+
+  // Test cases: each pair is {result type, argument type}.
+  // Only a representative subset of valid combinations is checked.
+  std::vector<TypePair> ValidTypes = {
+      {Int32Ty, FltTy},          {Int32Ty, DblTy},  {Int64Ty, FltTy},
+      {Int64Ty, DblTy},          {Int32Ty, HalfTy}, {Vec2xInt32Ty, Vec2xFltTy},
+      {Vec2xInt64Ty, Vec2xFltTy}};
+
+  // Caught only by the verifier; CreateIntrinsic errors out on bad arg types.
+  std::vector<TypePair> InvalidTypes = {
+      {VectorType::get(Int32Ty, ElementCount::getFixed(3)), Vec2xFltTy}};
+
+  auto testIntrinsic = [&](TypePair types, Intrinsic::ID ID, bool expectValid) {
+    Function *F =
+        Function::Create(FunctionType::get(types.first, {types.second}, false),
+                         Function::ExternalLinkage, "lround_fn", M.get());
+    BasicBlock *BB = BasicBlock::Create(C, "entry", F);
+    Builder.SetInsertPoint(BB);
+
+    Value *Arg = F->arg_begin();
+    Value *Result = Builder.CreateIntrinsic(types.first, ID, {Arg});
+    Builder.CreateRet(Result);
+
+    std::string Error;
+    raw_string_ostream ErrorOS(Error);
+    EXPECT_EQ(expectValid, !verifyFunction(*F, &ErrorOS)); // true => broken
+    if (!expectValid) {
+      EXPECT_TRUE(StringRef(ErrorOS.str())
+                      .contains("llvm.lround, llvm.llround: argument must be "
+                                "same length as result"));
+    }
+  };
+
+  // Valid pairs must verify for both llvm.lround and llvm.llround.
+  for (auto Types : ValidTypes) {
+    testIntrinsic(Types, Intrinsic::lround, true);
+    testIntrinsic(Types, Intrinsic::llround, true);
+  }
+
+  // Invalid pairs must fail with the element-count diagnostic for both.
+  for (auto Types : InvalidTypes) {
+    testIntrinsic(Types, Intrinsic::lround, false);
+    testIntrinsic(Types, Intrinsic::llround, false);
+  }
+}
 } // end namespace



More information about the llvm-commits mailing list