[llvm] [GlobalISel] Take the result size into account when const folding icmp (PR #134365)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 15 07:31:37 PDT 2025
https://github.com/KRM7 updated https://github.com/llvm/llvm-project/pull/134365
>From 49ce91e499c4aba289cf19c9eb1988e34da36c8d Mon Sep 17 00:00:00 2001
From: Krisztian Rugasi <Krisztian.Rugasi at hightec-rt.com>
Date: Tue, 15 Apr 2025 16:28:58 +0200
Subject: [PATCH] [GlobalISel] Take the result size into account when const folding icmp
---
llvm/include/llvm/CodeGen/GlobalISel/Utils.h | 1 +
llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp | 8 ++--
llvm/lib/CodeGen/GlobalISel/Utils.cpp | 45 ++++++++++++-------
llvm/unittests/CodeGen/GlobalISel/CSETest.cpp | 45 +++++++++++++++++++
4 files changed, 79 insertions(+), 20 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 44141844f42f4..35d21aa1d66d9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::optional<SmallVector<APInt>>
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+ unsigned DstScalarSizeInBits, unsigned ExtOp,
const MachineRegisterInfo &MRI);
/// Test if the given value is known to have exactly one bit set. This differs
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index bf8e847011d7c..10c72641ce2df 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
assert(SrcOps.size() == 3 && "Invalid sources");
assert(DstOps.size() == 1 && "Invalid dsts");
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
+ LLT DstTy = DstOps[0].getLLTTy(*getMRI());
+ auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
- if (std::optional<SmallVector<APInt>> Cst =
- ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
- SrcOps[2].getReg(), *getMRI())) {
+ if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
+ SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
+ DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
if (SrcTy.isVector())
return buildBuildVectorConstant(DstOps[0], *Cst);
return buildConstant(DstOps[0], Cst->front());
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 223d69c362185..04e69f882bb43 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1027,39 +1027,50 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::optional<SmallVector<APInt>>
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+ unsigned DstScalarSizeInBits, unsigned ExtOp,
const MachineRegisterInfo &MRI) {
- LLT Ty = MRI.getType(Op1);
- if (Ty != MRI.getType(Op2))
- return std::nullopt;
+ assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
+ ExtOp == TargetOpcode::G_ANYEXT);
- auto TryFoldScalar = [&MRI, Pred](Register LHS,
- Register RHS) -> std::optional<APInt> {
- auto LHSCst = getIConstantVRegVal(LHS, MRI);
+ const LLT Ty = MRI.getType(Op1);
+
+ auto GetICmpResultCst = [&](bool IsTrue) {
+ if (IsTrue)
+ return ExtOp == TargetOpcode::G_SEXT
+ ? APInt::getAllOnes(DstScalarSizeInBits)
+ : APInt::getOneBitSet(DstScalarSizeInBits, 0);
+ return APInt::getZero(DstScalarSizeInBits);
+ };
+
+ auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
auto RHSCst = getIConstantVRegVal(RHS, MRI);
- if (!LHSCst || !RHSCst)
+ if (!RHSCst)
+ return std::nullopt;
+ auto LHSCst = getIConstantVRegVal(LHS, MRI);
+ if (!LHSCst)
return std::nullopt;
switch (Pred) {
case CmpInst::Predicate::ICMP_EQ:
- return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
+ return GetICmpResultCst(LHSCst->eq(*RHSCst));
case CmpInst::Predicate::ICMP_NE:
- return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
+ return GetICmpResultCst(LHSCst->ne(*RHSCst));
case CmpInst::Predicate::ICMP_UGT:
- return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
+ return GetICmpResultCst(LHSCst->ugt(*RHSCst));
case CmpInst::Predicate::ICMP_UGE:
- return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
+ return GetICmpResultCst(LHSCst->uge(*RHSCst));
case CmpInst::Predicate::ICMP_ULT:
- return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
+ return GetICmpResultCst(LHSCst->ult(*RHSCst));
case CmpInst::Predicate::ICMP_ULE:
- return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
+ return GetICmpResultCst(LHSCst->ule(*RHSCst));
case CmpInst::Predicate::ICMP_SGT:
- return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
+ return GetICmpResultCst(LHSCst->sgt(*RHSCst));
case CmpInst::Predicate::ICMP_SGE:
- return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
+ return GetICmpResultCst(LHSCst->sge(*RHSCst));
case CmpInst::Predicate::ICMP_SLT:
- return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
+ return GetICmpResultCst(LHSCst->slt(*RHSCst));
case CmpInst::Predicate::ICMP_SLE:
- return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
+ return GetICmpResultCst(LHSCst->sle(*RHSCst));
default:
return std::nullopt;
}
diff --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
index cd6e32311a9ee..7c29c9d419c08 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
}
+ {
+ auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
+ EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+ EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1);
+ }
+
+ {
+ auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
+ EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+ EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0);
+ }
+
LLT VecTy = LLT::fixed_vector(2, s32);
LLT DstTy = LLT::fixed_vector(2, s1);
auto Three = CSEB.buildConstant(s32, 3);
@@ -508,6 +520,8 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
auto OneTwo = CSEB.buildBuildVector(VecTy, {One.getReg(0), Two.getReg(0)});
auto TwoThree =
CSEB.buildBuildVector(VecTy, {Two.getReg(0), Three.getReg(0)});
+ auto OneThree =
+ CSEB.buildBuildVector(VecTy, {One.getReg(0), Three.getReg(0)});
auto MinusOneOne =
CSEB.buildBuildVector(VecTy, {MinusOne.getReg(0), MinusOne.getReg(0)});
auto MinusOneTwo =
@@ -547,6 +561,36 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
// ICMP_SLE
CSEB.buildICmp(CmpInst::Predicate::ICMP_SLE, DstTy, MinusOneTwo, MinusOneOne);
+ {
+ auto I =
+ CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneOne, TwoThree);
+ EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+ const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+ const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+ EXPECT_EQ(HiCst.getSExtValue(), 0);
+ EXPECT_EQ(LoCst.getSExtValue(), 0);
+ }
+
+ {
+ auto I =
+ CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneThree, TwoThree);
+ EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+ const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+ const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+ EXPECT_EQ(HiCst.getSExtValue(), 0);
+ EXPECT_EQ(LoCst.getSExtValue(), -1);
+ }
+
+ {
+ auto I =
+ CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, TwoThree, TwoThree);
+ EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+ const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+ const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+ EXPECT_EQ(HiCst.getSExtValue(), -1);
+ EXPECT_EQ(LoCst.getSExtValue(), -1);
+ }
+
auto CheckStr = R"(
; CHECK: [[One:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
; CHECK: [[Two:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
@@ -558,6 +602,7 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[One]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Two]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[Two]]:_(s32), [[Three]]:_(s32)
+ ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Three]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusOne]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusTwo]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusTwo]]:_(s32), [[MinusThree]]:_(s32)
More information about the llvm-commits mailing list