[llvm] [GlobalISel] Take the result size into account when const folding icmp (PR #134365)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 04:56:04 PDT 2025


https://github.com/KRM7 updated https://github.com/llvm/llvm-project/pull/134365

>From a2a65fdffe3f01aa50d8c324caf10743d4fe080c Mon Sep 17 00:00:00 2001
From: Krisztian Rugasi <Krisztian.Rugasi at hightec-rt.com>
Date: Fri, 4 Apr 2025 13:55:25 +0200
Subject: [PATCH] [GlobalISel] Take the result size into account when const
 folding icmp

---
 llvm/include/llvm/CodeGen/GlobalISel/Utils.h  |  1 +
 llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp |  8 +++--
 llvm/lib/CodeGen/GlobalISel/Utils.cpp         | 30 +++++++++++--------
 llvm/unittests/CodeGen/GlobalISel/CSETest.cpp | 12 ++++++++
 4 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 44141844f42f4..3260b50ce317c 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
 
 std::optional<SmallVector<APInt>>
 ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                 unsigned DstSizeInBits, unsigned ExtOp,
                  const MachineRegisterInfo &MRI);
 
 /// Test if the given value is known to have exactly one bit set. This differs
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index bf8e847011d7c..10c72641ce2df 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
     assert(SrcOps.size() == 3 && "Invalid sources");
     assert(DstOps.size() == 1 && "Invalid dsts");
     LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
+    LLT DstTy = DstOps[0].getLLTTy(*getMRI());
+    auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
 
-    if (std::optional<SmallVector<APInt>> Cst =
-            ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
-                             SrcOps[2].getReg(), *getMRI())) {
+    if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
+            SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
+            DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
       if (SrcTy.isVector())
         return buildBuildVectorConstant(DstOps[0], *Cst);
       return buildConstant(DstOps[0], Cst->front());
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 223d69c362185..2ae583c3f81bf 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1027,13 +1027,19 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
 
 std::optional<SmallVector<APInt>>
 llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                       unsigned DstSizeInBits, unsigned ExtOp,
                        const MachineRegisterInfo &MRI) {
+  assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
+         ExtOp == TargetOpcode::G_ANYEXT);
+
   LLT Ty = MRI.getType(Op1);
   if (Ty != MRI.getType(Op2))
     return std::nullopt;
 
-  auto TryFoldScalar = [&MRI, Pred](Register LHS,
-                                    Register RHS) -> std::optional<APInt> {
+  // True is -1 for G_SEXT (built as signed so it is all-ones at any width).
+  const int64_t Sign = ExtOp == TargetOpcode::G_SEXT ? -1 : 1;
+  auto TryFoldScalar = [&MRI, Pred, DstSizeInBits, Sign](
+                           Register LHS, Register RHS) -> std::optional<APInt> {
     auto LHSCst = getIConstantVRegVal(LHS, MRI);
     auto RHSCst = getIConstantVRegVal(RHS, MRI);
     if (!LHSCst || !RHSCst)
@@ -1041,25 +1047,25 @@ llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
 
     switch (Pred) {
     case CmpInst::Predicate::ICMP_EQ:
-      return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->eq(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_NE:
-      return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->ne(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_UGT:
-      return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->ugt(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_UGE:
-      return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->uge(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_ULT:
-      return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->ult(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_ULE:
-      return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->ule(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_SGT:
-      return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->sgt(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_SGE:
-      return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->sge(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_SLT:
-      return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->slt(*RHSCst), Sign < 0);
     case CmpInst::Predicate::ICMP_SLE:
-      return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
+      return APInt(DstSizeInBits, Sign * LHSCst->sle(*RHSCst), Sign < 0);
     default:
       return std::nullopt;
     }
diff --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
index cd6e32311a9ee..0f479f295601b 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
     EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
   }
 
+  {
+    auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+    EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
+  }
+
+  {
+    auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+    EXPECT_FALSE(I->getOperand(1).getCImm()->getZExtValue());
+  }
+
   LLT VecTy = LLT::fixed_vector(2, s32);
   LLT DstTy = LLT::fixed_vector(2, s1);
   auto Three = CSEB.buildConstant(s32, 3);



More information about the llvm-commits mailing list