[llvm] [SandboxIR] Fix CmpInst::create() when it gets folded (PR #123408)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 13:52:07 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123408

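If the operands of a CmpInst are constants, then the comparison gets folded into a constant. Therefore CmpInst::create() should return a Value*, not a CmpInst*, and should handle the creation of the constant correctly. The same applies to CmpInst::createWithCopiedFlags(), which must not try to copy flags onto a folded constant.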

From 066e6fab3841ae4e8845526e4b4f2903a388b8bc Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 17 Jan 2025 13:34:44 -0800
Subject: [PATCH] [SandboxIR] Fix CmpInst::create() when it gets folded

If the operands of a CmpInst are constants, then the comparison gets folded into
a constant. Therefore CmpInst::create() should return a Value*, not a CmpInst*,
and should handle the creation of the constant correctly.
---
 llvm/include/llvm/SandboxIR/Instruction.h  | 13 ++++----
 llvm/lib/SandboxIR/Instruction.cpp         | 33 +++++++++++---------
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 35 +++++++++++++++++-----
 3 files changed, 53 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index 34a7feb63bec45..49ea6707ecd82f 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -2478,13 +2478,12 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
 public:
   using Predicate = llvm::CmpInst::Predicate;
 
-  static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
-                         InsertPosition Pos, Context &Ctx,
-                         const Twine &Name = "");
-  static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
-                                        const Instruction *FlagsSource,
-                                        InsertPosition Pos, Context &Ctx,
-                                        const Twine &Name = "");
+  static Value *create(Predicate Pred, Value *S1, Value *S2, InsertPosition Pos,
+                       Context &Ctx, const Twine &Name = "");
+  static Value *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
+                                      const Instruction *FlagsSource,
+                                      InsertPosition Pos, Context &Ctx,
+                                      const Twine &Name = "");
   void setPredicate(Predicate P);
   void swapOperands();
 
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index 0a7cd95124bb51..cc961418600e3f 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -926,21 +926,26 @@ void PHINode::removeIncomingValueIf(function_ref<bool(unsigned)> Predicate) {
   }
 }
 
-CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, InsertPosition Pos,
-                         Context &Ctx, const Twine &Name) {
+Value *CmpInst::create(Predicate P, Value *S1, Value *S2, InsertPosition Pos,
+                       Context &Ctx, const Twine &Name) {
   auto &Builder = setInsertPos(Pos);
-  auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
-  if (dyn_cast<llvm::ICmpInst>(LLVMI))
-    return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
-  return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
-}
-CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
-                                        const Instruction *F,
-                                        InsertPosition Pos, Context &Ctx,
-                                        const Twine &Name) {
-  CmpInst *Inst = create(P, S1, S2, Pos, Ctx, Name);
-  cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
-  return Inst;
+  auto *LLVMV = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
+  // It may have been folded into a constant.
+  if (auto *LLVMC = dyn_cast<llvm::Constant>(LLVMV))
+    return Ctx.getOrCreateConstant(LLVMC);
+  if (isa<llvm::ICmpInst>(LLVMV))
+    return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMV));
+  return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMV));
+}
+
+Value *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
+                                      const Instruction *F, InsertPosition Pos,
+                                      Context &Ctx, const Twine &Name) {
+  Value *V = create(P, S1, S2, Pos, Ctx, Name);
+  if (auto *C = dyn_cast<Constant>(V))
+    return C;
+  cast<llvm::CmpInst>(V->Val)->copyIRFlags(F->Val);
+  return V;
 }
 
 Type *CmpInst::makeCmpResultType(Type *OpndType) {
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 874c32c2d4398f..73e8ef283fc2ae 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5841,9 +5841,9 @@ define void @foo(i32 %i0, i32 %i1) {
     EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
     EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
   }
-  auto *NewCmp =
+  auto *NewCmp = cast<sandboxir::CmpInst>(
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
-                                 F.getArg(1), BB->begin(), Ctx, "NewCmp");
+                                 F.getArg(1), BB->begin(), Ctx, "NewCmp"));
   EXPECT_EQ(NewCmp, &*BB->begin());
   EXPECT_EQ(NewCmp->getPredicate(), llvm::CmpInst::ICMP_ULE);
   EXPECT_EQ(NewCmp->getOperand(0), F.getArg(0));
@@ -5856,6 +5856,16 @@ define void @foo(i32 %i0, i32 %i1) {
   sandboxir::Type *RT =
       sandboxir::CmpInst::makeCmpResultType(F.getArg(0)->getType());
   EXPECT_TRUE(RT->isIntegerTy(1)); // Only one bit in a single comparison
+
+  {
+    // Check create() when operands are constant.
+    auto *Const42 =
+        sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
+    auto *NewConstCmp =
+        sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, Const42, Const42,
+                                   BB->begin(), Ctx, "NewConstCmp");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewConstCmp));
+  }
 }
 
 TEST_F(SandboxIRTest, FCmpInst) {
@@ -5906,8 +5916,8 @@ define void @foo(float %f0, float %f1) {
   CopyFrom->setFastMathFlags(FastMathFlags::getFast());
 
   // create with default flags
-  auto *NewFCmp = sandboxir::CmpInst::create(
-      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), It1, Ctx, "NewFCmp");
+  auto *NewFCmp = cast<sandboxir::CmpInst>(sandboxir::CmpInst::create(
+      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), It1, Ctx, "NewFCmp"));
   EXPECT_EQ(NewFCmp->getPredicate(), llvm::CmpInst::FCMP_ONE);
   EXPECT_EQ(NewFCmp->getOperand(0), F.getArg(0));
   EXPECT_EQ(NewFCmp->getOperand(1), F.getArg(1));
@@ -5917,9 +5927,10 @@ define void @foo(float %f0, float %f1) {
   FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
   EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
   // create with copied flags
-  auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
-      llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, It1, Ctx,
-      "NewFCmpFlags");
+  auto *NewFCmpFlags =
+      cast<sandboxir::CmpInst>(sandboxir::CmpInst::createWithCopiedFlags(
+          llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, It1, Ctx,
+          "NewFCmpFlags"));
   EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
                CopyFrom->getFastMathFlags());
   EXPECT_EQ(NewFCmpFlags->getPredicate(), llvm::CmpInst::FCMP_ONE);
@@ -5928,6 +5939,16 @@ define void @foo(float %f0, float %f1) {
 #ifndef NDEBUG
   EXPECT_EQ(NewFCmpFlags->getName(), "NewFCmpFlags");
 #endif // NDEBUG
+
+  {
+    // Check create() when operands are constant.
+    auto *Const42 =
+        sandboxir::ConstantFP::get(sandboxir::Type::getFloatTy(Ctx), 42.0);
+    auto *NewConstCmp =
+        sandboxir::CmpInst::create(llvm::CmpInst::FCMP_ULE, Const42, Const42,
+                                   BB->begin(), Ctx, "NewConstCmp");
+    EXPECT_TRUE(isa<sandboxir::Constant>(NewConstCmp));
+  }
 }
 
 TEST_F(SandboxIRTest, UnreachableInst) {



More information about the llvm-commits mailing list