[llvm] 1cc02b0 - [SelectionDAG] Add helper function to check whether an SDValue is a neutral element. NFC.

Yeting Kuo via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 29 20:29:23 PDT 2022


Author: Yeting Kuo
Date: 2022-09-30T11:29:11+08:00
New Revision: 1cc02b05b752860637b3241fe944854bec2540c4

URL: https://github.com/llvm/llvm-project/commit/1cc02b05b752860637b3241fe944854bec2540c4
DIFF: https://github.com/llvm/llvm-project/commit/1cc02b05b752860637b3241fe944854bec2540c4.diff

LOG: [SelectionDAG] Add helper function to check whether an SDValue is a neutral element. NFC.

This helper makes it easier to work with neutral elements. Although I have only
found one use so far, I expect it to see more use, since many DAG combines
involve neutral elements.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D133866

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 063f88fe6c560..179578e08b70c 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1689,6 +1689,12 @@ bool isOneConstant(SDValue V);
 /// Returns true if \p V is a constant min signed integer value.
 bool isMinSignedConstant(SDValue V);
 
+/// Returns true if \p V is a constant that is a neutral (identity) element of
+/// \p Opc under \p Flags. If \p OperandNo is 0, checks for a left identity;
+/// otherwise, checks for a right identity. Returns false for non-constants.
+bool isNeutralConstant(unsigned Opc, SDNodeFlags Flags, SDValue V,
+                       unsigned OperandNo);
+
 /// Return the non-bitcasted source operand of \p V if it exists.
 /// If \p V is not a bitcasted value, it is returned as-is.
 SDValue peekThroughBitcasts(SDValue V);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 4e6d17bb81f9b..3b6f6c48b1840 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10747,6 +10747,65 @@ bool llvm::isMinSignedConstant(SDValue V) {
   return Const != nullptr && Const->isMinSignedValue();
 }
 
+bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
+                             unsigned OperandNo) {
+  if (auto *Const = dyn_cast<ConstantSDNode>(V)) { // Only ConstantSDNode; vector splats presumably not matched.
+    switch (Opcode) {
+    case ISD::ADD:
+    case ISD::OR:
+    case ISD::XOR:
+    case ISD::UMAX:
+      return Const->isZero(); // x op 0 == x
+    case ISD::MUL:
+      return Const->isOne(); // x * 1 == x
+    case ISD::AND:
+    case ISD::UMIN:
+      return Const->isAllOnes(); // x & -1 == x, umin(x, -1) == x
+    case ISD::SMAX:
+      return Const->isMinSignedValue(); // smax(x, INT_MIN) == x
+    case ISD::SMIN:
+      return Const->isMaxSignedValue(); // smin(x, INT_MAX) == x
+    case ISD::SUB:
+    case ISD::SHL:
+    case ISD::SRA:
+    case ISD::SRL:
+      return OperandNo == 1 && Const->isZero(); // Right identity only: x op 0 == x
+    case ISD::UDIV:
+    case ISD::SDIV:
+      return OperandNo == 1 && Const->isOne(); // Right identity only: x / 1 == x
+    } // Any other opcode leaves the switch and reaches the final 'return false'.
+  } else if (auto *ConstFP = dyn_cast<ConstantFPSDNode>(V)) { // FP constant.
+    switch (Opcode) {
+    case ISD::FADD:
+      return ConstFP->isZero() && // -0.0 always: x + -0.0 == x
+             (Flags.hasNoSignedZeros() || ConstFP->isNegative()); // +0.0 needs nsz: -0.0 + 0.0 == +0.0
+    case ISD::FSUB:
+      return OperandNo == 1 && ConstFP->isZero() && // +0.0 always: x - 0.0 == x
+             (Flags.hasNoSignedZeros() || !ConstFP->isNegative()); // -0.0 needs nsz: -0.0 - -0.0 == +0.0
+    case ISD::FMUL:
+      return ConstFP->isExactlyValue(1.0); // x * 1.0 == x
+    case ISD::FDIV:
+      return OperandNo == 1 && ConstFP->isExactlyValue(1.0); // Right identity only: x / 1.0 == x
+    case ISD::FMINNUM:
+    case ISD::FMAXNUM: {
+      // fminnum identity: qNaN; +Inf if nnan; +largest finite if nnan+ninf.
+      EVT VT = V.getValueType();
+      const fltSemantics &Semantics = SelectionDAG::EVTToAPFloatSemantics(VT);
+      APFloat NeutralAF = !Flags.hasNoNaNs()
+                              ? APFloat::getQNaN(Semantics)
+                              : !Flags.hasNoInfs()
+                                    ? APFloat::getInf(Semantics)
+                                    : APFloat::getLargest(Semantics);
+      if (Opcode == ISD::FMAXNUM)
+        NeutralAF.changeSign(); // fmaxnum uses the negated value.
+
+      return ConstFP->isExactlyValue(NeutralAF); // NOTE(review): presumably bitwise, so other NaN payloads won't match -- confirm.
+    }
+    }
+  }
+  return false; // Not a constant, or no identity known for Opcode/OperandNo.
+}
+
 SDValue llvm::peekThroughBitcasts(SDValue V) {
   while (V.getOpcode() == ISD::BITCAST)
     V = V.getOperand(0);

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 03b3b7d23ec26..6b124c8372a48 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7714,17 +7714,10 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
   if (!isOneConstant(ScalarV.getOperand(2)))
     return SDValue();
 
-  // TODO: Deal with value other than neutral element.
-  auto IsRVVNeutralElement = [Opc, &DAG](SDNode *N, SDValue V) {
-    if (Opc == ISD::FADD && N->getFlags().hasNoSignedZeros() &&
-        isNullFPConstant(V))
-      return true;
-    return DAG.getNeutralElement(Opc, SDLoc(V), V.getSimpleValueType(),
-                                 N->getFlags()) == V;
-  };
-
   // Check the scalar of ScalarV is neutral element
-  if (!IsRVVNeutralElement(N, ScalarV.getOperand(1)))
+  // TODO: Handle scalar values other than the neutral element.
+  if (!isNeutralConstant(N->getOpcode(), N->getFlags(), ScalarV.getOperand(1),
+                         0)) // OperandNo 0: check for a left identity.
     return SDValue();
 
   if (!ScalarV.hasOneUse())


        


More information about the llvm-commits mailing list