[llvm] 93ec08d - [DAG] Move SIGN_EXTEND_INREG constant folding inside FoldConstantArithmetic
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat Oct 19 12:57:40 PDT 2024
Author: Simon Pilgrim
Date: 2024-10-19T20:57:07+01:00
New Revision: 93ec08d62971d51a239fba8468d3cf9cb9e54fb0
URL: https://github.com/llvm/llvm-project/commit/93ec08d62971d51a239fba8468d3cf9cb9e54fb0
DIFF: https://github.com/llvm/llvm-project/commit/93ec08d62971d51a239fba8468d3cf9cb9e54fb0.diff
LOG: [DAG] Move SIGN_EXTEND_INREG constant folding inside FoldConstantArithmetic
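Move the SIGN_EXTEND_INREG constant-folding logic (scalar constants, constant BUILD_VECTORs and constant SPLAT_VECTORs) out of SelectionDAG::getNode and into SelectionDAG::FoldConstantArithmetic, and update visitSIGN_EXTEND_INREG to call FoldConstantArithmetic instead of getNode.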
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 98eed6b7503d10..c892bdcd7fd837 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -14819,8 +14819,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
return DAG.getConstant(0, DL, VT);
// fold (sext_in_reg c1) -> c1
- if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
- return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0, N1);
+ if (SDValue C =
+ DAG.FoldConstantArithmetic(ISD::SIGN_EXTEND_INREG, DL, VT, {N0, N1}))
+ return C;
// If the input is already sign extended, just drop the extension.
if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 43d49674297f6f..55cebc28e49275 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6659,6 +6659,44 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
if (TLI->isCommutativeBinOp(Opcode))
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1]))
return FoldSymbolOffset(Opcode, VT, GA, Ops[0].getNode());
+
+ // fold (sext_in_reg c1) -> c2
+ if (Opcode == ISD::SIGN_EXTEND_INREG) {
+ EVT EVT = cast<VTSDNode>(Ops[1])->getVT();
+
+ auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
+ unsigned FromBits = EVT.getScalarSizeInBits();
+ Val <<= Val.getBitWidth() - FromBits;
+ Val.ashrInPlace(Val.getBitWidth() - FromBits);
+ return getConstant(Val, DL, ConstantVT);
+ };
+
+ if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0])) {
+ const APInt &Val = C1->getAPIntValue();
+ return SignExtendInReg(Val, VT);
+ }
+
+ if (ISD::isBuildVectorOfConstantSDNodes(Ops[0].getNode())) {
+ SmallVector<SDValue, 8> ScalarOps;
+ llvm::EVT OpVT = Ops[0].getOperand(0).getValueType();
+ for (int I = 0, E = VT.getVectorNumElements(); I != E; ++I) {
+ SDValue Op = Ops[0].getOperand(I);
+ if (Op.isUndef()) {
+ ScalarOps.push_back(getUNDEF(OpVT));
+ continue;
+ }
+ APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue();
+ ScalarOps.push_back(SignExtendInReg(Val, OpVT));
+ }
+ return getBuildVector(VT, DL, ScalarOps);
+ }
+
+ if (Ops[0].getOpcode() == ISD::SPLAT_VECTOR &&
+ isa<ConstantSDNode>(Ops[0].getOperand(0)))
+ return getNode(ISD::SPLAT_VECTOR, DL, VT,
+ SignExtendInReg(Ops[0].getConstantOperandAPInt(0),
+ Ops[0].getOperand(0).getValueType()));
+ }
}
// This is for vector folding only from here on.
@@ -7205,41 +7243,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
"Vector element counts must match in SIGN_EXTEND_INREG");
assert(EVT.bitsLE(VT) && "Not extending!");
if (EVT == VT) return N1; // Not actually extending
-
- auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
- unsigned FromBits = EVT.getScalarSizeInBits();
- Val <<= Val.getBitWidth() - FromBits;
- Val.ashrInPlace(Val.getBitWidth() - FromBits);
- return getConstant(Val, DL, ConstantVT);
- };
-
- if (N1C) {
- const APInt &Val = N1C->getAPIntValue();
- return SignExtendInReg(Val, VT);
- }
-
- if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
- SmallVector<SDValue, 8> Ops;
- llvm::EVT OpVT = N1.getOperand(0).getValueType();
- for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
- SDValue Op = N1.getOperand(i);
- if (Op.isUndef()) {
- Ops.push_back(getUNDEF(OpVT));
- continue;
- }
- ConstantSDNode *C = cast<ConstantSDNode>(Op);
- APInt Val = C->getAPIntValue();
- Ops.push_back(SignExtendInReg(Val, OpVT));
- }
- return getBuildVector(VT, DL, Ops);
- }
-
- if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
- isa<ConstantSDNode>(N1.getOperand(0)))
- return getNode(
- ISD::SPLAT_VECTOR, DL, VT,
- SignExtendInReg(N1.getConstantOperandAPInt(0),
- N1.getOperand(0).getValueType()));
break;
}
case ISD::FP_TO_SINT_SAT:
More information about the llvm-commits
mailing list