[llvm] [DAG] Add SelectionDAG::getShiftAmountConstant APInt variant (PR #81484)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 06:24:10 PST 2024


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/81484

Asserts that the shift amount is in range and updates ExpandShiftByConstant to use getShiftAmountConstant (and legal shift amount types).

>From e638e0c6a4f884426202ba1afb5b5ebd0ecaf009 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Mon, 12 Feb 2024 14:22:09 +0000
Subject: [PATCH] [DAG] Add SelectionDAG::getShiftAmountConstant APInt variant

Asserts that the shift amount is in range and updates ExpandShiftByConstant to use getShiftAmountConstant (and legal shift amount types).
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  2 +
 .../SelectionDAG/LegalizeIntegerTypes.cpp     | 59 ++++++++++---------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++
 3 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 886ec0b7940ca8..7bb12d8f065c9d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -668,6 +668,8 @@ class SelectionDAG {
                             bool isTarget = false);
   SDValue getShiftAmountConstant(uint64_t Val, EVT VT, const SDLoc &DL,
                                  bool LegalTypes = true);
+  SDValue getShiftAmountConstant(const APInt &Val, EVT VT, const SDLoc &DL,
+                                 bool LegalTypes = true);
   SDValue getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
                                bool isTarget = false);
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 39b7e061554141..e73a0921a46f5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2824,25 +2824,26 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
   EVT NVT = InL.getValueType();
   unsigned VTBits = N->getValueType(0).getSizeInBits();
   unsigned NVTBits = NVT.getSizeInBits();
-  EVT ShTy = N->getOperand(1).getValueType();
 
   if (N->getOpcode() == ISD::SHL) {
     if (Amt.uge(VTBits)) {
       Lo = Hi = DAG.getConstant(0, DL, NVT);
     } else if (Amt.ugt(NVTBits)) {
       Lo = DAG.getConstant(0, DL, NVT);
-      Hi = DAG.getNode(ISD::SHL, DL,
-                       NVT, InL, DAG.getConstant(Amt - NVTBits, DL, ShTy));
+      Hi = DAG.getNode(ISD::SHL, DL, NVT, InL,
+                       DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
     } else if (Amt == NVTBits) {
       Lo = DAG.getConstant(0, DL, NVT);
       Hi = InL;
     } else {
-      Lo = DAG.getNode(ISD::SHL, DL, NVT, InL, DAG.getConstant(Amt, DL, ShTy));
-      Hi = DAG.getNode(ISD::OR, DL, NVT,
-                       DAG.getNode(ISD::SHL, DL, NVT, InH,
-                                   DAG.getConstant(Amt, DL, ShTy)),
-                       DAG.getNode(ISD::SRL, DL, NVT, InL,
-                                   DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
+      Lo = DAG.getNode(ISD::SHL, DL, NVT, InL,
+                       DAG.getShiftAmountConstant(Amt, NVT, DL));
+      Hi = DAG.getNode(
+          ISD::OR, DL, NVT,
+          DAG.getNode(ISD::SHL, DL, NVT, InH,
+                      DAG.getShiftAmountConstant(Amt, NVT, DL)),
+          DAG.getNode(ISD::SRL, DL, NVT, InL,
+                      DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
     }
     return;
   }
@@ -2851,19 +2852,21 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
     if (Amt.uge(VTBits)) {
       Lo = Hi = DAG.getConstant(0, DL, NVT);
     } else if (Amt.ugt(NVTBits)) {
-      Lo = DAG.getNode(ISD::SRL, DL,
-                       NVT, InH, DAG.getConstant(Amt - NVTBits, DL, ShTy));
+      Lo = DAG.getNode(ISD::SRL, DL, NVT, InH,
+                       DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
       Hi = DAG.getConstant(0, DL, NVT);
     } else if (Amt == NVTBits) {
       Lo = InH;
       Hi = DAG.getConstant(0, DL, NVT);
     } else {
-      Lo = DAG.getNode(ISD::OR, DL, NVT,
-                       DAG.getNode(ISD::SRL, DL, NVT, InL,
-                                   DAG.getConstant(Amt, DL, ShTy)),
-                       DAG.getNode(ISD::SHL, DL, NVT, InH,
-                                   DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
-      Hi = DAG.getNode(ISD::SRL, DL, NVT, InH, DAG.getConstant(Amt, DL, ShTy));
+      Lo = DAG.getNode(
+          ISD::OR, DL, NVT,
+          DAG.getNode(ISD::SRL, DL, NVT, InL,
+                      DAG.getShiftAmountConstant(Amt, NVT, DL)),
+          DAG.getNode(ISD::SHL, DL, NVT, InH,
+                      DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
+      Hi = DAG.getNode(ISD::SRL, DL, NVT, InH,
+                       DAG.getShiftAmountConstant(Amt, NVT, DL));
     }
     return;
   }
@@ -2871,23 +2874,25 @@ void DAGTypeLegalizer::ExpandShiftByConstant(SDNode *N, const APInt &Amt,
   assert(N->getOpcode() == ISD::SRA && "Unknown shift!");
   if (Amt.uge(VTBits)) {
     Hi = Lo = DAG.getNode(ISD::SRA, DL, NVT, InH,
-                          DAG.getConstant(NVTBits - 1, DL, ShTy));
+                          DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
   } else if (Amt.ugt(NVTBits)) {
     Lo = DAG.getNode(ISD::SRA, DL, NVT, InH,
-                     DAG.getConstant(Amt - NVTBits, DL, ShTy));
+                     DAG.getShiftAmountConstant(Amt - NVTBits, NVT, DL));
     Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
-                     DAG.getConstant(NVTBits - 1, DL, ShTy));
+                     DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
   } else if (Amt == NVTBits) {
     Lo = InH;
     Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
-                     DAG.getConstant(NVTBits - 1, DL, ShTy));
+                     DAG.getShiftAmountConstant(NVTBits - 1, NVT, DL));
   } else {
-    Lo = DAG.getNode(ISD::OR, DL, NVT,
-                     DAG.getNode(ISD::SRL, DL, NVT, InL,
-                                 DAG.getConstant(Amt, DL, ShTy)),
-                     DAG.getNode(ISD::SHL, DL, NVT, InH,
-                                 DAG.getConstant(-Amt + NVTBits, DL, ShTy)));
-    Hi = DAG.getNode(ISD::SRA, DL, NVT, InH, DAG.getConstant(Amt, DL, ShTy));
+    Lo = DAG.getNode(
+        ISD::OR, DL, NVT,
+        DAG.getNode(ISD::SRL, DL, NVT, InL,
+                    DAG.getShiftAmountConstant(Amt, NVT, DL)),
+        DAG.getNode(ISD::SHL, DL, NVT, InH,
+                    DAG.getShiftAmountConstant(-Amt + NVTBits, NVT, DL)));
+    Hi = DAG.getNode(ISD::SRA, DL, NVT, InH,
+                     DAG.getShiftAmountConstant(Amt, NVT, DL));
   }
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 55eee780d512c8..421bb516ad242f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1734,6 +1734,12 @@ SDValue SelectionDAG::getShiftAmountConstant(uint64_t Val, EVT VT,
   return getConstant(Val, DL, ShiftVT);
 }
 
+SDValue SelectionDAG::getShiftAmountConstant(const APInt &Val, EVT VT,
+                                             const SDLoc &DL, bool LegalTypes) {
+  assert(Val.ult(VT.getScalarSizeInBits()) && "Out of range shift");
+  return getShiftAmountConstant(Val.getZExtValue(), VT, DL, LegalTypes);
+}
+
 SDValue SelectionDAG::getVectorIdxConstant(uint64_t Val, const SDLoc &DL,
                                            bool isTarget) {
   return getConstant(Val, DL, TLI->getVectorIdxTy(getDataLayout()), isTarget);



More information about the llvm-commits mailing list