[llvm] 0926d94 - [GlobalISel] Take the result size into account when const folding icmp (#134365)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 5 10:01:57 PDT 2025


Author: KRM7
Date: 2025-05-05T19:01:53+02:00
New Revision: 0926d94453096a37ec324a51d88ec203baedc3ea

URL: https://github.com/llvm/llvm-project/commit/0926d94453096a37ec324a51d88ec203baedc3ea
DIFF: https://github.com/llvm/llvm-project/commit/0926d94453096a37ec324a51d88ec203baedc3ea.diff

LOG: [GlobalISel] Take the result size into account when const folding icmp (#134365)

The current implementation always creates a 1-bit constant for the
result of the `G_ICMP`, which causes issues if the destination
register is wider than that. With asserts enabled, it causes a
crash in `buildConstant`:
```
llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp:322: virtual MachineInstrBuilder llvm::MachineIRBuilder::buildConstant(const DstOp &, const ConstantInt &): Assertion `EltTy.getScalarSizeInBits() == Val.getBitWidth() && "creating constant with the wrong size"' failed. 
```

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/Utils.h
    llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/unittests/CodeGen/GlobalISel/CSETest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 44141844f42f4..35d21aa1d66d9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
 
 std::optional<SmallVector<APInt>>
 ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                 unsigned DstScalarSizeInBits, unsigned ExtOp,
                  const MachineRegisterInfo &MRI);
 
 /// Test if the given value is known to have exactly one bit set. This differs

diff  --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index bf8e847011d7c..10c72641ce2df 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
     assert(SrcOps.size() == 3 && "Invalid sources");
     assert(DstOps.size() == 1 && "Invalid dsts");
     LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
+    LLT DstTy = DstOps[0].getLLTTy(*getMRI());
+    auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
 
-    if (std::optional<SmallVector<APInt>> Cst =
-            ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
-                             SrcOps[2].getReg(), *getMRI())) {
+    if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
+            SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
+            DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
       if (SrcTy.isVector())
         return buildBuildVectorConstant(DstOps[0], *Cst);
       return buildConstant(DstOps[0], Cst->front());

diff  --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index d8cc86b34a819..445a716d20478 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1027,39 +1027,50 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
 
 std::optional<SmallVector<APInt>>
 llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+                       unsigned DstScalarSizeInBits, unsigned ExtOp,
                        const MachineRegisterInfo &MRI) {
-  LLT Ty = MRI.getType(Op1);
-  if (Ty != MRI.getType(Op2))
-    return std::nullopt;
+  assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
+         ExtOp == TargetOpcode::G_ANYEXT);
 
-  auto TryFoldScalar = [&MRI, Pred](Register LHS,
-                                    Register RHS) -> std::optional<APInt> {
-    auto LHSCst = getIConstantVRegVal(LHS, MRI);
+  const LLT Ty = MRI.getType(Op1);
+
+  auto GetICmpResultCst = [&](bool IsTrue) {
+    if (IsTrue)
+      return ExtOp == TargetOpcode::G_SEXT
+                 ? APInt::getAllOnes(DstScalarSizeInBits)
+                 : APInt::getOneBitSet(DstScalarSizeInBits, 0);
+    return APInt::getZero(DstScalarSizeInBits);
+  };
+
+  auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
     auto RHSCst = getIConstantVRegVal(RHS, MRI);
-    if (!LHSCst || !RHSCst)
+    if (!RHSCst)
+      return std::nullopt;
+    auto LHSCst = getIConstantVRegVal(LHS, MRI);
+    if (!LHSCst)
       return std::nullopt;
 
     switch (Pred) {
     case CmpInst::Predicate::ICMP_EQ:
-      return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
+      return GetICmpResultCst(LHSCst->eq(*RHSCst));
     case CmpInst::Predicate::ICMP_NE:
-      return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
+      return GetICmpResultCst(LHSCst->ne(*RHSCst));
     case CmpInst::Predicate::ICMP_UGT:
-      return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
+      return GetICmpResultCst(LHSCst->ugt(*RHSCst));
     case CmpInst::Predicate::ICMP_UGE:
-      return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
+      return GetICmpResultCst(LHSCst->uge(*RHSCst));
     case CmpInst::Predicate::ICMP_ULT:
-      return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
+      return GetICmpResultCst(LHSCst->ult(*RHSCst));
     case CmpInst::Predicate::ICMP_ULE:
-      return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
+      return GetICmpResultCst(LHSCst->ule(*RHSCst));
     case CmpInst::Predicate::ICMP_SGT:
-      return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
+      return GetICmpResultCst(LHSCst->sgt(*RHSCst));
     case CmpInst::Predicate::ICMP_SGE:
-      return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
+      return GetICmpResultCst(LHSCst->sge(*RHSCst));
     case CmpInst::Predicate::ICMP_SLT:
-      return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
+      return GetICmpResultCst(LHSCst->slt(*RHSCst));
     case CmpInst::Predicate::ICMP_SLE:
-      return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
+      return GetICmpResultCst(LHSCst->sle(*RHSCst));
     default:
       return std::nullopt;
     }

diff  --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
index cd6e32311a9ee..7c29c9d419c08 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
     EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
   }
 
+  {
+    auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+    EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1);
+  }
+
+  {
+    auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
+    EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0);
+  }
+
   LLT VecTy = LLT::fixed_vector(2, s32);
   LLT DstTy = LLT::fixed_vector(2, s1);
   auto Three = CSEB.buildConstant(s32, 3);
@@ -508,6 +520,8 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
   auto OneTwo = CSEB.buildBuildVector(VecTy, {One.getReg(0), Two.getReg(0)});
   auto TwoThree =
       CSEB.buildBuildVector(VecTy, {Two.getReg(0), Three.getReg(0)});
+  auto OneThree =
+      CSEB.buildBuildVector(VecTy, {One.getReg(0), Three.getReg(0)});
   auto MinusOneOne =
       CSEB.buildBuildVector(VecTy, {MinusOne.getReg(0), MinusOne.getReg(0)});
   auto MinusOneTwo =
@@ -547,6 +561,36 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
   // ICMP_SLE
   CSEB.buildICmp(CmpInst::Predicate::ICMP_SLE, DstTy, MinusOneTwo, MinusOneOne);
 
+  {
+    auto I =
+        CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneOne, TwoThree);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+    const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+    const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+    EXPECT_EQ(HiCst.getSExtValue(), 0);
+    EXPECT_EQ(LoCst.getSExtValue(), 0);
+  }
+
+  {
+    auto I =
+        CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneThree, TwoThree);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+    const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+    const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+    EXPECT_EQ(HiCst.getSExtValue(), 0);
+    EXPECT_EQ(LoCst.getSExtValue(), -1);
+  }
+
+  {
+    auto I =
+        CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, TwoThree, TwoThree);
+    EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
+    const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
+    const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
+    EXPECT_EQ(HiCst.getSExtValue(), -1);
+    EXPECT_EQ(LoCst.getSExtValue(), -1);
+  }
+
   auto CheckStr = R"(
   ; CHECK: [[One:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
   ; CHECK: [[Two:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
@@ -558,6 +602,7 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[One]]:_(s32)
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Two]]:_(s32)
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[Two]]:_(s32), [[Three]]:_(s32)
+  ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Three]]:_(s32)
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusOne]]:_(s32)
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusTwo]]:_(s32)
   ; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusTwo]]:_(s32), [[MinusThree]]:_(s32)


        


More information about the llvm-commits mailing list