[llvm] [GlobalISel] Take the result size into account when const folding icmp (PR #134365)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 4 04:03:32 PDT 2025
https://github.com/KRM7 created https://github.com/llvm/llvm-project/pull/134365
The current implementation always creates a 1-bit constant for the result of the `G_ICMP`, which causes problems when the destination register is wider than 1 bit. With asserts enabled, it triggers an assertion failure in `buildConstant`:
```
llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp:322: virtual MachineInstrBuilder llvm::MachineIRBuilder::buildConstant(const DstOp &, const ConstantInt &): Assertion `EltTy.getScalarSizeInBits() == Val.getBitWidth() && "creating constant with the wrong size"' failed.
```
From 929f09caf5de7d6543bb7606696920f76c1316fb Mon Sep 17 00:00:00 2001
From: Krisztian Rugasi <Krisztian.Rugasi at hightec-rt.com>
Date: Fri, 4 Apr 2025 12:56:37 +0200
Subject: [PATCH] [GlobalISel] Take the result size into account when const
folding icmp
---
llvm/include/llvm/CodeGen/GlobalISel/Utils.h | 2 +-
llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp | 7 +++--
llvm/lib/CodeGen/GlobalISel/Utils.cpp | 31 ++++++++++++-------
3 files changed, 24 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 44141844f42f4..123ff4f47777a 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -325,7 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::optional<SmallVector<APInt>>
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
- const MachineRegisterInfo &MRI);
+ Register Dst, unsigned ExtOp, const MachineRegisterInfo &MRI);
/// Test if the given value is known to have exactly one bit set. This differs
/// from computeKnownBits in that it doesn't necessarily determine which bit is
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index bf8e847011d7c..5bb4df2b46769 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -190,9 +190,10 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
assert(DstOps.size() == 1 && "Invalid dsts");
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
- if (std::optional<SmallVector<APInt>> Cst =
- ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
- SrcOps[2].getReg(), *getMRI())) {
+ if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
+ SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
+ DstOps[0].getReg(), getBoolExtOp(SrcTy.isVector(), false),
+ *getMRI())) {
if (SrcTy.isVector())
return buildBuildVectorConstant(DstOps[0], *Cst);
return buildConstant(DstOps[0], Cst->front());
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 223d69c362185..5db6ddb7405a6 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1027,13 +1027,20 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::optional<SmallVector<APInt>>
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
+ Register Dst, unsigned ExtOp,
const MachineRegisterInfo &MRI) {
+ assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
+ ExtOp == TargetOpcode::G_ANYEXT);
+
LLT Ty = MRI.getType(Op1);
if (Ty != MRI.getType(Op2))
return std::nullopt;
- auto TryFoldScalar = [&MRI, Pred](Register LHS,
- Register RHS) -> std::optional<APInt> {
+ const uint64_t DstSize = MRI.getType(Dst).getScalarSizeInBits();
+ const int64_t Sign = ExtOp == TargetOpcode::G_SEXT ? -1 : 1;
+
+ auto TryFoldScalar = [&MRI, Pred, DstSize, Sign](
+ Register LHS, Register RHS) -> std::optional<APInt> {
auto LHSCst = getIConstantVRegVal(LHS, MRI);
auto RHSCst = getIConstantVRegVal(RHS, MRI);
if (!LHSCst || !RHSCst)
@@ -1041,25 +1048,25 @@ llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
switch (Pred) {
case CmpInst::Predicate::ICMP_EQ:
- return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->eq(*RHSCst));
case CmpInst::Predicate::ICMP_NE:
- return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->ne(*RHSCst));
case CmpInst::Predicate::ICMP_UGT:
- return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->ugt(*RHSCst));
case CmpInst::Predicate::ICMP_UGE:
- return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->uge(*RHSCst));
case CmpInst::Predicate::ICMP_ULT:
- return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->ult(*RHSCst));
case CmpInst::Predicate::ICMP_ULE:
- return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->ule(*RHSCst));
case CmpInst::Predicate::ICMP_SGT:
- return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->sgt(*RHSCst));
case CmpInst::Predicate::ICMP_SGE:
- return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->sge(*RHSCst));
case CmpInst::Predicate::ICMP_SLT:
- return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->slt(*RHSCst));
case CmpInst::Predicate::ICMP_SLE:
- return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
+ return APInt(DstSize, Sign * LHSCst->sle(*RHSCst));
default:
return std::nullopt;
}
More information about the llvm-commits
mailing list