[llvm] 93ec08d - [DAG] Move SIGN_EXTEND_INREG constant folding inside FoldConstantArithmetic

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 19 12:57:40 PDT 2024


Author: Simon Pilgrim
Date: 2024-10-19T20:57:07+01:00
New Revision: 93ec08d62971d51a239fba8468d3cf9cb9e54fb0

URL: https://github.com/llvm/llvm-project/commit/93ec08d62971d51a239fba8468d3cf9cb9e54fb0
DIFF: https://github.com/llvm/llvm-project/commit/93ec08d62971d51a239fba8468d3cf9cb9e54fb0.diff

LOG: [DAG] Move SIGN_EXTEND_INREG constant folding inside FoldConstantArithmetic

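Update DAGCombiner::visitSIGN_EXTEND_INREG to call FoldConstantArithmetic directly instead of going through getNode, whose SIGN_EXTEND_INREG case no longer contains the constant folding logic.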

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 98eed6b7503d10..c892bdcd7fd837 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -14819,8 +14819,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
     return DAG.getConstant(0, DL, VT);
 
   // fold (sext_in_reg c1) -> c1
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
-    return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0, N1);
+  if (SDValue C =
+          DAG.FoldConstantArithmetic(ISD::SIGN_EXTEND_INREG, DL, VT, {N0, N1}))
+    return C;
 
   // If the input is already sign extended, just drop the extension.
   if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 43d49674297f6f..55cebc28e49275 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6659,6 +6659,44 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
     if (TLI->isCommutativeBinOp(Opcode))
       if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1]))
         return FoldSymbolOffset(Opcode, VT, GA, Ops[0].getNode());
+
+    // fold (sext_in_reg c1) -> c2
+    if (Opcode == ISD::SIGN_EXTEND_INREG) {
+      EVT EVT = cast<VTSDNode>(Ops[1])->getVT();
+
+      auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
+        unsigned FromBits = EVT.getScalarSizeInBits();
+        Val <<= Val.getBitWidth() - FromBits;
+        Val.ashrInPlace(Val.getBitWidth() - FromBits);
+        return getConstant(Val, DL, ConstantVT);
+      };
+
+      if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0])) {
+        const APInt &Val = C1->getAPIntValue();
+        return SignExtendInReg(Val, VT);
+      }
+
+      if (ISD::isBuildVectorOfConstantSDNodes(Ops[0].getNode())) {
+        SmallVector<SDValue, 8> ScalarOps;
+        llvm::EVT OpVT = Ops[0].getOperand(0).getValueType();
+        for (int I = 0, E = VT.getVectorNumElements(); I != E; ++I) {
+          SDValue Op = Ops[0].getOperand(I);
+          if (Op.isUndef()) {
+            ScalarOps.push_back(getUNDEF(OpVT));
+            continue;
+          }
+          APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue();
+          ScalarOps.push_back(SignExtendInReg(Val, OpVT));
+        }
+        return getBuildVector(VT, DL, ScalarOps);
+      }
+
+      if (Ops[0].getOpcode() == ISD::SPLAT_VECTOR &&
+          isa<ConstantSDNode>(Ops[0].getOperand(0)))
+        return getNode(ISD::SPLAT_VECTOR, DL, VT,
+                       SignExtendInReg(Ops[0].getConstantOperandAPInt(0),
+                                       Ops[0].getOperand(0).getValueType()));
+    }
   }
 
   // This is for vector folding only from here on.
@@ -7205,41 +7243,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
            "Vector element counts must match in SIGN_EXTEND_INREG");
     assert(EVT.bitsLE(VT) && "Not extending!");
     if (EVT == VT) return N1;  // Not actually extending
-
-    auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
-      unsigned FromBits = EVT.getScalarSizeInBits();
-      Val <<= Val.getBitWidth() - FromBits;
-      Val.ashrInPlace(Val.getBitWidth() - FromBits);
-      return getConstant(Val, DL, ConstantVT);
-    };
-
-    if (N1C) {
-      const APInt &Val = N1C->getAPIntValue();
-      return SignExtendInReg(Val, VT);
-    }
-
-    if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
-      SmallVector<SDValue, 8> Ops;
-      llvm::EVT OpVT = N1.getOperand(0).getValueType();
-      for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
-        SDValue Op = N1.getOperand(i);
-        if (Op.isUndef()) {
-          Ops.push_back(getUNDEF(OpVT));
-          continue;
-        }
-        ConstantSDNode *C = cast<ConstantSDNode>(Op);
-        APInt Val = C->getAPIntValue();
-        Ops.push_back(SignExtendInReg(Val, OpVT));
-      }
-      return getBuildVector(VT, DL, Ops);
-    }
-
-    if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
-        isa<ConstantSDNode>(N1.getOperand(0)))
-      return getNode(
-          ISD::SPLAT_VECTOR, DL, VT,
-          SignExtendInReg(N1.getConstantOperandAPInt(0),
-                          N1.getOperand(0).getValueType()));
     break;
   }
   case ISD::FP_TO_SINT_SAT:


        


More information about the llvm-commits mailing list