[llvm] 42f0010 - [DAG] Constant fold + canonicalize integer binops before SimplifyVBinOp call

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 17 04:02:36 PST 2021


Author: Simon Pilgrim
Date: 2021-12-17T12:02:27Z
New Revision: 42f00106b7fea55495a89013ea6ba58e03ded936

URL: https://github.com/llvm/llvm-project/commit/42f00106b7fea55495a89013ea6ba58e03ded936
DIFF: https://github.com/llvm/llvm-project/commit/42f00106b7fea55495a89013ea6ba58e03ded936.diff

LOG: [DAG] Constant fold + canonicalize integer binops before SimplifyVBinOp call

SimplifyVBinOp still has a FoldConstantArithmetic call which, now that it is no longer vector-specific, we should be able to remove (once the fp binops are tidied up); in the meantime we can at least clean up the integer opcodes so that the basic constant/undef handling is performed in common code first.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 912b5985c0e4f..4637ce56ff258 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2260,6 +2260,21 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
   EVT VT = N0.getValueType();
   SDLoc DL(N);
 
+  // fold (add x, undef) -> undef
+  if (N0.isUndef())
+    return N0;
+  if (N1.isUndef())
+    return N1;
+
+  // fold (add c1, c2) -> c1+c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
+
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
@@ -2268,23 +2283,6 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
     // fold (add x, 0) -> x, vector edition
     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
       return N0;
-    if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
-      return N1;
-  }
-
-  // fold (add x, undef) -> undef
-  if (N0.isUndef())
-    return N0;
-
-  if (N1.isUndef())
-    return N1;
-
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0)) {
-    // canonicalize constant to RHS
-    if (!DAG.isConstantIntBuildVectorOrConstantInt(N1))
-      return DAG.getNode(ISD::ADD, DL, VT, N1, N0);
-    // fold (add c1, c2) -> c1+c2
-    return DAG.FoldConstantArithmetic(ISD::ADD, DL, VT, {N0, N1});
   }
 
   // fold (add x, 0) -> x
@@ -3260,6 +3258,15 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
   EVT VT = N0.getValueType();
   SDLoc DL(N);
 
+  // fold (sub x, x) -> 0
+  // FIXME: Refactor this and xor and other similar operations together.
+  if (N0 == N1)
+    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
+
+  // fold (sub c1, c2) -> c3
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
+    return C;
+
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
@@ -3270,15 +3277,6 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
       return N0;
   }
 
-  // fold (sub x, x) -> 0
-  // FIXME: Refactor this and xor and other similar operations together.
-  if (N0 == N1)
-    return tryFoldToZero(DL, TLI, VT, DAG, LegalOperations);
-
-  // fold (sub c1, c2) -> c3
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SUB, DL, VT, {N0, N1}))
-    return C;
-
   if (SDValue NewSel = foldBinOpIntoSelect(N))
     return NewSel;
 
@@ -3781,6 +3779,15 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   if (N0.isUndef() || N1.isUndef())
     return DAG.getConstant(0, SDLoc(N), VT);
 
+  // fold (mul c1, c2) -> c1*c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS (vector doesn't have to splat)
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
+
   bool N1IsConst = false;
   bool N1IsOpaqueConst = false;
   APInt ConstValue1;
@@ -3802,15 +3809,6 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
     }
   }
 
-  // fold (mul c1, c2) -> c1*c2
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MUL, SDLoc(N), VT, {N0, N1}))
-    return C;
-
-  // canonicalize constant to RHS (vector doesn't have to splat)
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
-
   // fold (mul x, 0) -> 0
   if (N1IsConst && ConstValue1.isZero())
     return N1;
@@ -4140,17 +4138,17 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
   EVT CCVT = getSetCCResultType(VT);
   SDLoc DL(N);
 
+  // fold (sdiv c1, c2) -> c1/c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
+    return C;
+
   // fold vector ops
   if (VT.isVector())
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
       return FoldedVOp;
 
-  // fold (sdiv c1, c2) -> c1/c2
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, {N0, N1}))
-    return C;
-
   // fold (sdiv X, -1) -> 0-X
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C && N1C->isAllOnes())
     return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0);
 
@@ -4284,17 +4282,17 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
   EVT CCVT = getSetCCResultType(VT);
   SDLoc DL(N);
 
+  // fold (udiv c1, c2) -> c1/c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
+    return C;
+
   // fold vector ops
   if (VT.isVector())
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
       return FoldedVOp;
 
-  // fold (udiv c1, c2) -> c1/c2
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, {N0, N1}))
-    return C;
-
   // fold (udiv X, -1) -> select(X == -1, 1, 0)
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C && N1C->isAllOnes())
     return DAG.getSelect(DL, VT, DAG.getSetCC(DL, CCVT, N0, N1, ISD::SETEQ),
                          DAG.getConstant(1, DL, VT),
@@ -4463,6 +4461,15 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
+  // fold (mulhs c1, c2)
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS.
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
+
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
       return FoldedVOp;
@@ -4474,15 +4481,6 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
       return DAG.getConstant(0, DL, VT);
   }
 
-  // fold (mulhs c1, c2)
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
-    return C;
-
-  // canonicalize constant to RHS.
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(ISD::MULHS, DL, N->getVTList(), N1, N0);
-
   // fold (mulhs x, 0) -> 0
   if (isNullConstant(N1))
     return N1;
@@ -4523,6 +4521,15 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
 
+  // fold (mulhu c1, c2)
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS.
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
+
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
       return FoldedVOp;
@@ -4534,15 +4541,6 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
       return DAG.getConstant(0, DL, VT);
   }
 
-  // fold (mulhu c1, c2)
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
-    return C;
-
-  // canonicalize constant to RHS.
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(ISD::MULHU, DL, N->getVTList(), N1, N0);
-
   // fold (mulhu x, 0) -> 0
   if (isNullConstant(N1))
     return N1;
@@ -4905,11 +4903,6 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
   unsigned Opcode = N->getOpcode();
   SDLoc DL(N);
 
-  // fold vector ops
-  if (VT.isVector())
-    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
-      return FoldedVOp;
-
   // fold operation with constant operands.
   if (SDValue C = DAG.FoldConstantArithmetic(Opcode, DL, VT, {N0, N1}))
     return C;
@@ -4917,7 +4910,12 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
   // canonicalize constant to RHS
   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
       !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(N->getOpcode(), DL, VT, N1, N0);
+    return DAG.getNode(Opcode, DL, VT, N1, N0);
+
+  // fold vector ops
+  if (VT.isVector())
+    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
+      return FoldedVOp;
 
   // Is sign bits are zero, flip between UMIN/UMAX and SMIN/SMAX.
   // Only do this if the current op isn't legal and the flipped is.
@@ -5790,6 +5788,15 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   if (N0 == N1)
     return N0;
 
+  // fold (and c1, c2) -> c1&c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
+
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
@@ -5837,22 +5844,13 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     }
   }
 
-  // fold (and c1, c2) -> c1&c2
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, {N0, N1}))
-    return C;
-
-  // canonicalize constant to RHS
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0);
-
   // fold (and x, -1) -> x
   if (isAllOnesConstant(N1))
     return N0;
 
   // if (and x, c) is known to be zero, return 0
   unsigned BitWidth = VT.getScalarSizeInBits();
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(BitWidth)))
     return DAG.getConstant(0, SDLoc(N), VT);
 
@@ -6559,21 +6557,25 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   if (N0 == N1)
     return N0;
 
+  // fold (or c1, c2) -> c1|c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
+    return C;
+
+  // canonicalize constant to RHS
+  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
+
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
       return FoldedVOp;
 
     // fold (or x, 0) -> x, vector edition
-    if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
-      return N1;
     if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
       return N0;
 
     // fold (or x, -1) -> -1, vector edition
-    if (ISD::isConstantSplatVectorAllOnes(N0.getNode()))
-      // do not return N0, because undef node may exist in N0
-      return DAG.getAllOnesConstant(SDLoc(N), N0.getValueType());
     if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
       // do not return N1, because undef node may exist in N1
       return DAG.getAllOnesConstant(SDLoc(N), N1.getValueType());
@@ -6642,16 +6644,6 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     }
   }
 
-  // fold (or c1, c2) -> c1|c2
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::OR, SDLoc(N), VT, {N0, N1}))
-    return C;
-
-  // canonicalize constant to RHS
-  if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
-    return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0);
-
   // fold (or x, 0) -> x
   if (isNullConstant(N1))
     return N0;
@@ -6664,6 +6656,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     return NewSel;
 
   // fold (or x, c) -> c iff (x & ~c) == 0
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
     return N1;
 
@@ -7954,18 +7947,6 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
   EVT VT = N0.getValueType();
   SDLoc DL(N);
 
-  // fold vector ops
-  if (VT.isVector()) {
-    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
-      return FoldedVOp;
-
-    // fold (xor x, 0) -> x, vector edition
-    if (ISD::isConstantSplatVectorAllZeros(N0.getNode()))
-      return N1;
-    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
-      return N0;
-  }
-
   // fold (xor undef, undef) -> 0. This is a common idiom (misuse).
   if (N0.isUndef() && N1.isUndef())
     return DAG.getConstant(0, DL, VT);
@@ -7982,9 +7963,19 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
 
   // canonicalize constant to RHS
   if (DAG.isConstantIntBuildVectorOrConstantInt(N0) &&
-     !DAG.isConstantIntBuildVectorOrConstantInt(N1))
+      !DAG.isConstantIntBuildVectorOrConstantInt(N1))
     return DAG.getNode(ISD::XOR, DL, VT, N1, N0);
 
+  // fold vector ops
+  if (VT.isVector()) {
+    if (SDValue FoldedVOp = SimplifyVBinOp(N, DL))
+      return FoldedVOp;
+
+    // fold (xor x, 0) -> x, vector edition
+    if (ISD::isConstantSplatVectorAllZeros(N1.getNode()))
+      return N0;
+  }
+
   // fold (xor x, 0) -> x
   if (isNullConstant(N1))
     return N0;
@@ -8422,6 +8413,10 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   EVT ShiftVT = N1.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
+  // fold (shl c1, c2) -> c1<<c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
+    return C;
+
   // fold vector ops
   if (VT.isVector()) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
@@ -8447,12 +8442,6 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     }
   }
 
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-
-  // fold (shl c1, c2) -> c1<<c2
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N), VT, {N0, N1}))
-    return C;
-
   if (SDValue NewSel = foldBinOpIntoSelect(N))
     return NewSel;
 
@@ -8571,6 +8560,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   // fold (shl (sr[la] exact X,  C1), C2) -> (shl    X, (C2-C1)) if C1 <= C2
   // fold (shl (sr[la] exact X,  C1), C2) -> (sr[la] X, (C2-C1)) if C1  > C2
   // TODO - support non-uniform vector shift amounts.
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C && (N0.getOpcode() == ISD::SRL || N0.getOpcode() == ISD::SRA) &&
       N0->getFlags().hasExact()) {
     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
@@ -8771,6 +8761,10 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   EVT VT = N0.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
+  // fold (sra c1, c2) -> (sra c1, c2)
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
+    return C;
+
   // Arithmetic shifting an all-sign-bit value is a no-op.
   // fold (sra 0, x) -> 0
   // fold (sra -1, x) -> -1
@@ -8782,17 +8776,12 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
       return FoldedVOp;
 
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-
-  // fold (sra c1, c2) -> (sra c1, c2)
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, {N0, N1}))
-    return C;
-
   if (SDValue NewSel = foldBinOpIntoSelect(N))
     return NewSel;
 
   // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
   // sext_inreg.
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
     unsigned LowBits = OpSizeInBits - (unsigned)N1C->getZExtValue();
     EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), LowBits);
@@ -8975,21 +8964,20 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   EVT VT = N0.getValueType();
   unsigned OpSizeInBits = VT.getScalarSizeInBits();
 
+  // fold (srl c1, c2) -> c1 >>u c2
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
+    return C;
+
   // fold vector ops
   if (VT.isVector())
     if (SDValue FoldedVOp = SimplifyVBinOp(N, SDLoc(N)))
       return FoldedVOp;
 
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-
-  // fold (srl c1, c2) -> c1 >>u c2
-  if (SDValue C = DAG.FoldConstantArithmetic(ISD::SRL, SDLoc(N), VT, {N0, N1}))
-    return C;
-
   if (SDValue NewSel = foldBinOpIntoSelect(N))
     return NewSel;
 
   // if (srl x, c) is known to be zero, return 0
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N1C &&
       DAG.MaskedValueIsZero(SDValue(N, 0), APInt::getAllOnes(OpSizeInBits)))
     return DAG.getConstant(0, SDLoc(N), VT);
@@ -22600,6 +22588,7 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N, const SDLoc &DL) {
   SDNodeFlags Flags = N->getFlags();
 
   // See if we can constant fold the vector operation.
+  // TODO: This isn't vector-specific any more - remove me.
   if (SDValue Fold = DAG.FoldConstantArithmetic(Opcode, SDLoc(LHS),
                                                 LHS.getValueType(), Ops))
     return Fold;


        


More information about the llvm-commits mailing list