[llvm] d30e941 - [DAG] Add SelectionDAG::getShiftAmountConstant APInt variant (#81484)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 13 00:06:20 PST 2024
Author: Simon Pilgrim
Date: 2024-02-13T08:06:16Z
New Revision: d30e941a03f7e70fb7875ede7b5d80342982e3a8
URL: https://github.com/llvm/llvm-project/commit/d30e941a03f7e70fb7875ede7b5d80342982e3a8
DIFF: https://github.com/llvm/llvm-project/commit/d30e941a03f7e70fb7875ede7b5d80342982e3a8.diff
LOG: [DAG] Add SelectionDAG::getShiftAmountConstant APInt variant (#81484)
The new APInt variant asserts that the shift amount is in range. This change also updates ExpandShiftByConstant to use getShiftAmountConstant (and therefore legal shift amount types).
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 886ec0b7940ca8..7bb12d8f065c9d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -668,6 +668,8 @@ class SelectionDAG {
bool isTarget = false);
SDValue getShiftAmountConstant(uint64_t Val, EVT VT, const SDLoc &DL,
bool LegalTypes = true);
+ SDValue getShiftAmountConstant(const APInt &Val, EVT VT, const SDLoc &DL,
+ bool LegalTypes = true);
SDValue getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
bool isTarget = false);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 39b7e061554141..e73a0921a46f5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2824,25 +2824,26 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
EVT NVT = InL.getValueType();
unsigned VTBits = N->getValueType(0).getSizeInBits();
unsigned NVTBits = NVT.getSizeInBits();
- EVT ShTy = N->getOperand(1).getValueType();
if (N->getOpcode() == ISD::SHL) {
if (Amt.uge(VTBits)) {
Lo = Hi = DAG.getConstant(0, DL, NVT);
} else if (Amt.ugt(NVTBits)) {
Lo = DAG.getConstant(0, DL, NVT);
- Hi = DAG.getNode(ISD::SHL, DL,
- NVT, InL, DAG.getConstant(Amt - NVTBits, DL, ShTy));
+ Hi = DAG.getNode(ISD::SHL, DL, NVT, InL,
+ DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
} else if (Amt == NVTBits) {
Lo = DAG.getConstant(0, DL, NVT);
Hi = InL;
} else {
- Lo = DAG.getNode(ISD::SHL, DL, NVT, InL, DAG.getConstant(Amt, DL, ShTy));
- Hi = DAG.getNode(ISD::OR, DL, NVT,
- DAG.getNode(ISD::SHL, DL, NVT, InH,
- DAG.getConstant(Amt, DL, ShTy)),
- DAG.getNode(ISD::SRL, DL, NVT, InL,
- DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
+ Lo = DAG.getNode(ISD::SHL, DL, NVT, InL,
+ DAG.getShiftAmountConstant(Amt, NVT, DL));
+ Hi = DAG.getNode(
+ ISD::OR, DL, NVT,
+ DAG.getNode(ISD::SHL, DL, NVT, InH,
+ DAG.getShiftAmountConstant(Amt, NVT, DL)),
+ DAG.getNode(ISD::SRL, DL, NVT, InL,
+ DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
}
return;
}
@@ -2851,19 +2852,21 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
if (Amt.uge(VTBits)) {
Lo = Hi = DAG.getConstant(0, DL, NVT);
} else if (Amt.ugt(NVTBits)) {
- Lo = DAG.getNode(ISD::SRL, DL,
- NVT, InH, DAG.getConstant(Amt - NVTBits, DL, ShTy));
+ Lo = DAG.getNode(ISD::SRL, DL, NVT, InH,
+ DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
Hi = DAG.getConstant(0, DL, NVT);
} else if (Amt == NVTBits) {
Lo = InH;
Hi = DAG.getConstant(0, DL, NVT);
} else {
- Lo = DAG.getNode(ISD::OR, DL, NVT,
- DAG.getNode(ISD::SRL, DL, NVT, InL,
- DAG.getConstant(Amt, DL, ShTy)),
- DAG.getNode(ISD::SHL, DL, NVT, InH,
- DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
- Hi = DAG.getNode(ISD::SRL, DL, NVT, InH, DAG.getConstant(Amt, DL, ShTy));
+ Lo = DAG.getNode(
+ ISD::OR, DL, NVT,
+ DAG.getNode(ISD::SRL, DL, NVT, InL,
+ DAG.getShiftAmountConstant(Amt, NVT, DL)),
+ DAG.getNode(ISD::SHL, DL, NVT, InH,
+ DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
+ Hi = DAG.getNode(ISD::SRL, DL, NVT, InH,
+ DAG.getShiftAmountConstant(Amt, NVT, DL));
}
return;
}
@@ -2871,23 +2874,25 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
assert(N->getOpcode() == ISD::SRA && "Unknown shift!");
if (Amt.uge(VTBits)) {
Hi = Lo = DAG.getNode(ISD::SRA, DL, NVT, InH,
- DAG.getConstant(NVTBits - 1, DL, ShTy));
+ DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
} else if (Amt.ugt(NVTBits)) {
Lo = DAG.getNode(ISD::SRA, DL, NVT, InH,
- DAG.getConstant(Amt - NVTBits, DL, ShTy));
+ DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
- DAG.getConstant(NVTBits - 1, DL, ShTy));
+ DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
} else if (Amt == NVTBits) {
Lo = InH;
Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
- DAG.getConstant(NVTBits - 1, DL, ShTy));
+ DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
} else {
- Lo = DAG.getNode(ISD::OR, DL, NVT,
- DAG.getNode(ISD::SRL, DL, NVT, InL,
- DAG.getConstant(Amt, DL, ShTy)),
- DAG.getNode(ISD::SHL, DL, NVT, InH,
- DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
- Hi = DAG.getNode(ISD::SRA, DL, NVT, InH, DAG.getConstant(Amt, DL, ShTy));
+ Lo = DAG.getNode(
+ ISD::OR, DL, NVT,
+ DAG.getNode(ISD::SRL, DL, NVT, InL,
+ DAG.getShiftAmountConstant(Amt, NVT, DL)),
+ DAG.getNode(ISD::SHL, DL, NVT, InH,
+ DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
+ Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
+ DAG.getShiftAmountConstant(Amt, NVT, DL));
}
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 55eee780d512c8..421bb516ad242f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1734,6 +1734,12 @@ SDValue SelectionDAG::getShiftAmountConstant(uint64_t Val, EVT VT,
return getConstant(Val, DL, ShiftVT);
}
+SDValue SelectionDAG::getShiftAmountConstant(const APInt &Val, EVT VT, // APInt overload: range-checks Val against VT's scalar width, then forwards to the uint64_t overload.
+ const SDLoc &DL, bool LegalTypes) { // DL and LegalTypes are passed through unchanged.
+ assert(Val.ult(VT.getScalarSizeInBits()) && "Out of range shift"); // Require Val < VT's scalar bit width (checked only when asserts are enabled). NOTE(review): ExpandShiftByConstant passes -Amt + NVTBits, which equals NVTBits when Amt == 0 and would fail this check; confirm zero shift amounts never reach that path.
+ return getShiftAmountConstant(Val.getZExtValue(), VT, DL, LegalTypes); // An in-range Val is below a uint64_t bit width, so narrowing it via getZExtValue() loses nothing.
+}
+
SDValue SelectionDAG::getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
bool isTarget) {
return getConstant(Val, DL, TLI->getVectorIdxTy(getDataLayout()), isTarget);
More information about the llvm-commits mailing list