[llvm] [SandboxIR] Fix CmpInst::create() when it gets folded (PR #123408)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 17 13:52:07 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123408
If both operands of a CmpInst are constants, the IRBuilder folds the compare into a constant instead of creating an instruction. Therefore CmpInst::create() and createWithCopiedFlags() should return a Value*, not a CmpInst*, and should wrap the folded constant correctly.
>From 066e6fab3841ae4e8845526e4b4f2903a388b8bc Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 17 Jan 2025 13:34:44 -0800
Subject: [PATCH] [SandboxIR] Fix CmpInst::create() when it gets folded
If both operands of a CmpInst are constants, the IRBuilder folds the compare
into a constant instead of creating an instruction. Therefore CmpInst::create()
and createWithCopiedFlags() should return a Value*, not a CmpInst*, and should
wrap the folded constant correctly.
---
llvm/include/llvm/SandboxIR/Instruction.h | 13 ++++----
llvm/lib/SandboxIR/Instruction.cpp | 33 +++++++++++---------
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 35 +++++++++++++++++-----
3 files changed, 53 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index 34a7feb63bec45..49ea6707ecd82f 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -2478,13 +2478,12 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
public:
using Predicate = llvm::CmpInst::Predicate;
- static CmpInst *create(Predicate Pred, Value *S1, Value *S2,
- InsertPosition Pos, Context &Ctx,
- const Twine &Name = "");
- static CmpInst *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
- const Instruction *FlagsSource,
- InsertPosition Pos, Context &Ctx,
- const Twine &Name = "");
+ static Value *create(Predicate Pred, Value *S1, Value *S2, InsertPosition Pos,
+ Context &Ctx, const Twine &Name = "");
+ static Value *createWithCopiedFlags(Predicate Pred, Value *S1, Value *S2,
+ const Instruction *FlagsSource,
+ InsertPosition Pos, Context &Ctx,
+ const Twine &Name = "");
void setPredicate(Predicate P);
void swapOperands();
diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp
index 0a7cd95124bb51..cc961418600e3f 100644
--- a/llvm/lib/SandboxIR/Instruction.cpp
+++ b/llvm/lib/SandboxIR/Instruction.cpp
@@ -926,21 +926,26 @@ void PHINode::removeIncomingValueIf(function_ref<bool(unsigned)> Predicate) {
}
}
-CmpInst *CmpInst::create(Predicate P, Value *S1, Value *S2, InsertPosition Pos,
- Context &Ctx, const Twine &Name) {
+Value *CmpInst::create(Predicate P, Value *S1, Value *S2, InsertPosition Pos,
+ Context &Ctx, const Twine &Name) {
auto &Builder = setInsertPos(Pos);
- auto *LLVMI = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
- if (dyn_cast<llvm::ICmpInst>(LLVMI))
- return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMI));
- return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMI));
-}
-CmpInst *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
- const Instruction *F,
- InsertPosition Pos, Context &Ctx,
- const Twine &Name) {
- CmpInst *Inst = create(P, S1, S2, Pos, Ctx, Name);
- cast<llvm::CmpInst>(Inst->Val)->copyIRFlags(F->Val);
- return Inst;
+ auto *LLVMV = Builder.CreateCmp(P, S1->Val, S2->Val, Name);
+ // The compare may be folded (e.g. constant operands): return the Constant.
+ if (auto *LLVMC = dyn_cast<llvm::Constant>(LLVMV))
+ return Ctx.getOrCreateConstant(LLVMC);
+ if (isa<llvm::ICmpInst>(LLVMV))
+ return Ctx.createICmpInst(cast<llvm::ICmpInst>(LLVMV));
+ return Ctx.createFCmpInst(cast<llvm::FCmpInst>(LLVMV));
+}
+
+Value *CmpInst::createWithCopiedFlags(Predicate P, Value *S1, Value *S2,
+ const Instruction *F, InsertPosition Pos,
+ Context &Ctx, const Twine &Name) {
+ Value *V = create(P, S1, S2, Pos, Ctx, Name);
+ if (auto *C = dyn_cast<Constant>(V))
+ return C;
+ cast<llvm::CmpInst>(V->Val)->copyIRFlags(F->Val);
+ return V;
}
Type *CmpInst::makeCmpResultType(Type *OpndType) {
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 874c32c2d4398f..73e8ef283fc2ae 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5841,9 +5841,9 @@ define void @foo(i32 %i0, i32 %i1) {
EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
}
- auto *NewCmp =
+ auto *NewCmp = cast<sandboxir::CmpInst>(
sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),
- F.getArg(1), BB->begin(), Ctx, "NewCmp");
+ F.getArg(1), BB->begin(), Ctx, "NewCmp"));
EXPECT_EQ(NewCmp, &*BB->begin());
EXPECT_EQ(NewCmp->getPredicate(), llvm::CmpInst::ICMP_ULE);
EXPECT_EQ(NewCmp->getOperand(0), F.getArg(0));
@@ -5856,6 +5856,16 @@ define void @foo(i32 %i0, i32 %i1) {
sandboxir::Type *RT =
sandboxir::CmpInst::makeCmpResultType(F.getArg(0)->getType());
EXPECT_TRUE(RT->isIntegerTy(1)); // Only one bit in a single comparison
+
+ {
+ // Constant operands: create() folds the compare and returns a Constant.
+ auto *Const42 =
+ sandboxir::ConstantInt::get(sandboxir::Type::getInt32Ty(Ctx), 42);
+ auto *NewConstCmp =
+ sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, Const42, Const42,
+ BB->begin(), Ctx, "NewConstCmp");
+ EXPECT_TRUE(isa<sandboxir::Constant>(NewConstCmp));
+ }
}
TEST_F(SandboxIRTest, FCmpInst) {
@@ -5906,8 +5916,8 @@ define void @foo(float %f0, float %f1) {
CopyFrom->setFastMathFlags(FastMathFlags::getFast());
// create with default flags
- auto *NewFCmp = sandboxir::CmpInst::create(
- llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), It1, Ctx, "NewFCmp");
+ auto *NewFCmp = cast<sandboxir::CmpInst>(sandboxir::CmpInst::create(
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), It1, Ctx, "NewFCmp"));
EXPECT_EQ(NewFCmp->getPredicate(), llvm::CmpInst::FCMP_ONE);
EXPECT_EQ(NewFCmp->getOperand(0), F.getArg(0));
EXPECT_EQ(NewFCmp->getOperand(1), F.getArg(1));
@@ -5917,9 +5927,10 @@ define void @foo(float %f0, float %f1) {
FastMathFlags DefaultFMF = NewFCmp->getFastMathFlags();
EXPECT_TRUE(CopyFrom->getFastMathFlags() != DefaultFMF);
// create with copied flags
- auto *NewFCmpFlags = sandboxir::CmpInst::createWithCopiedFlags(
- llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, It1, Ctx,
- "NewFCmpFlags");
+ auto *NewFCmpFlags =
+ cast<sandboxir::CmpInst>(sandboxir::CmpInst::createWithCopiedFlags(
+ llvm::CmpInst::FCMP_ONE, F.getArg(0), F.getArg(1), CopyFrom, It1, Ctx,
+ "NewFCmpFlags"));
EXPECT_FALSE(NewFCmpFlags->getFastMathFlags() !=
CopyFrom->getFastMathFlags());
EXPECT_EQ(NewFCmpFlags->getPredicate(), llvm::CmpInst::FCMP_ONE);
@@ -5928,6 +5939,16 @@ define void @foo(float %f0, float %f1) {
#ifndef NDEBUG
EXPECT_EQ(NewFCmpFlags->getName(), "NewFCmpFlags");
#endif // NDEBUG
+
+ {
+ // Constant operands: create() folds the compare and returns a Constant.
+ auto *Const42 =
+ sandboxir::ConstantFP::get(sandboxir::Type::getFloatTy(Ctx), 42.0);
+ auto *NewConstCmp =
+ sandboxir::CmpInst::create(llvm::CmpInst::FCMP_ULE, Const42, Const42,
+ BB->begin(), Ctx, "NewConstCmp");
+ EXPECT_TRUE(isa<sandboxir::Constant>(NewConstCmp));
+ }
}
TEST_F(SandboxIRTest, UnreachableInst) {
More information about the llvm-commits
mailing list