[llvm] [SelectionDAG][RISCV] Move VP_REDUCE* legalization to LegalizeDAG.cpp. (PR #90522)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 29 21:48:57 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/90522

>From 96cc3ac6bd6e8d59d04c64137a4732aa5fac1ce8 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 29 Apr 2024 13:48:21 -0700
Subject: [PATCH 1/2] [SelectionDAG][RISCV] Move VP_REDUCE* legalization to
 LegalizeDAG.cpp.

LegalizeVectorOps is responsible for legalizing nodes that perform an
operation on each element and may therefore need to be scalarized.

This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR,
SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.

This patch drops any nodes with a scalar result from LegalizeVectorOps
and handles them in LegalizeDAG instead.

This required moving the reduction promotion to LegalizeDAG. I have
removed the support for integer promotion, as it was incorrect for
integer min/max reductions: the operands were widened with ANY_EXTEND,
which leaves the high bits undefined. Since it was untested, it is best
to assert on it until it is really needed.

There are a couple of regressions that can be fixed with a small DAG
combine, which I will do as a follow-up.
---
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 64 +++++++++++++++-
 .../SelectionDAG/LegalizeVectorOps.cpp        | 73 ++-----------------
 .../rvv/fixed-vectors-reduction-int-vp.ll     | 26 ++++---
 .../CodeGen/RISCV/rvv/vreductions-int-vp.ll   |  7 +-
 4 files changed, 87 insertions(+), 83 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..9dd40531abb005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -180,6 +180,13 @@ class SelectionDAGLegalize {
                              SmallVectorImpl<SDValue> &Results);
   SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl);
 
+  /// Implements vector reduce operation promotion.
+  ///
+  /// All vector operands are promoted to a vector type with larger element
+  /// type, and the start value is promoted to a larger scalar type. Then the
+  /// result is truncated back to the original scalar type.
+  void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+
   SDValue ExpandPARITY(SDValue Op, const SDLoc &dl);
 
   SDValue ExpandExtractFromVectorThroughStack(SDValue Op);
@@ -2979,6 +2986,49 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
   return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT));
 }
 
+void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
+                                            SmallVectorImpl<SDValue> &Results) {
+  MVT VecVT = Node->getOperand(1).getSimpleValueType();
+  MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
+  MVT ScalarVT = Node->getSimpleValueType(0);
+  MVT NewScalarVT = NewVecVT.getVectorElementType();
+
+  SDLoc DL(Node);
+  SmallVector<SDValue, 4> Operands(Node->getNumOperands());
+
+  // promote the initial value.
+  // FIXME: Support integer.
+  assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
+         "Only FP promotion is supported");
+  Operands[0] =
+      DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
+
+  for (unsigned j = 1; j != Node->getNumOperands(); ++j)
+    if (Node->getOperand(j).getValueType().isVector() &&
+        !(ISD::isVPOpcode(Node->getOpcode()) &&
+          ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
+      // promote the vector operand.
+      // FIXME: Support integer.
+      assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
+             "Only FP promotion is supported");
+      Operands[j] =
+          DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
+    } else {
+      Operands[j] = Node->getOperand(j); // Skip VL operand.
+    }
+
+  SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
+                            Node->getFlags());
+
+  if (ScalarVT.isFloatingPoint())
+    Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+  else
+    Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+
+  Results.push_back(Res);
+}
+
 bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   LLVM_DEBUG(dbgs() << "Trying to expand node\n");
   SmallVector<SDValue, 8> Results;
@@ -4955,7 +5005,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
   if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
       Node->getOpcode() == ISD::STRICT_FSETCC ||
-      Node->getOpcode() == ISD::STRICT_FSETCCS)
+      Node->getOpcode() == ISD::STRICT_FSETCCS ||
+      Node->getOpcode() == ISD::VP_REDUCE_FADD ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMUL ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMAX ||
+      Node->getOpcode() == ISD::VP_REDUCE_FMIN ||
+      Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD)
     OVT = Node->getOperand(1).getSimpleValueType();
   if (Node->getOpcode() == ISD::BR_CC ||
       Node->getOpcode() == ISD::SELECT_CC)
@@ -5613,6 +5668,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
                     DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
     break;
   }
+  case ISD::VP_REDUCE_FADD:
+  case ISD::VP_REDUCE_FMUL:
+  case ISD::VP_REDUCE_FMAX:
+  case ISD::VP_REDUCE_FMIN:
+  case ISD::VP_REDUCE_SEQ_FADD:
+    PromoteReduction(Node, Results);
+    break;
   }
 
   // Replace the original node with the legalized result.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..423df9ae6b2a55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -176,13 +176,6 @@ class VectorLegalizer {
   /// truncated back to the original type.
   void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
-  /// Implements vector reduce operation promotion.
-  ///
-  /// All vector operands are promoted to a vector type with larger element
-  /// type, and the start value is promoted to a larger scalar type. Then the
-  /// result is truncated back to the original scalar type.
-  void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
-
   /// Implements vector setcc operation promotion.
   ///
   /// All vector operands are promoted to a vector type with larger element
@@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       if (Action != TargetLowering::Legal)                                     \
         break;                                                                 \
     }                                                                          \
+    /* Defer non-vector results to LegalizeDAG. */                             \
+    if (!Node->getValueType(0).isVector()) {                                   \
+      Action = TargetLowering::Legal;                                          \
+      break;                                                                   \
+    }                                                                          \
     Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT);            \
   } break;
 #include "llvm/IR/VPIntrinsics.def"
@@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
   return true;
 }
 
-void VectorLegalizer::PromoteReduction(SDNode *Node,
-                                       SmallVectorImpl<SDValue> &Results) {
-  MVT VecVT = Node->getOperand(1).getSimpleValueType();
-  MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
-  MVT ScalarVT = Node->getSimpleValueType(0);
-  MVT NewScalarVT = NewVecVT.getVectorElementType();
-
-  SDLoc DL(Node);
-  SmallVector<SDValue, 4> Operands(Node->getNumOperands());
-
-  // promote the initial value.
-  if (Node->getOperand(0).getValueType().isFloatingPoint())
-    Operands[0] =
-        DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-  else
-    Operands[0] =
-        DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-
-  for (unsigned j = 1; j != Node->getNumOperands(); ++j)
-    if (Node->getOperand(j).getValueType().isVector() &&
-        !(ISD::isVPOpcode(Node->getOpcode()) &&
-          ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
-      // promote the vector operand.
-      if (Node->getOperand(j).getValueType().isFloatingPoint())
-        Operands[j] =
-            DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
-      else
-        Operands[j] =
-            DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
-    else
-      Operands[j] = Node->getOperand(j); // Skip VL operand.
-
-  SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
-                            Node->getFlags());
-
-  if (ScalarVT.isFloatingPoint())
-    Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
-                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
-  else
-    Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
-
-  Results.push_back(Res);
-}
-
 void VectorLegalizer::PromoteSETCC(SDNode *Node,
                                    SmallVectorImpl<SDValue> &Results) {
   MVT VecVT = Node->getOperand(0).getSimpleValueType();
@@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
     // Promote the operation by extending the operand.
     PromoteFP_TO_INT(Node, Results);
     return;
-  case ISD::VP_REDUCE_ADD:
-  case ISD::VP_REDUCE_MUL:
-  case ISD::VP_REDUCE_AND:
-  case ISD::VP_REDUCE_OR:
-  case ISD::VP_REDUCE_XOR:
-  case ISD::VP_REDUCE_SMAX:
-  case ISD::VP_REDUCE_SMIN:
-  case ISD::VP_REDUCE_UMAX:
-  case ISD::VP_REDUCE_UMIN:
-  case ISD::VP_REDUCE_FADD:
-  case ISD::VP_REDUCE_FMUL:
-  case ISD::VP_REDUCE_FMAX:
-  case ISD::VP_REDUCE_FMIN:
-  case ISD::VP_REDUCE_SEQ_FADD:
-    // Promote the operation by extending the operand.
-    PromoteReduction(Node, Results);
-    return;
   case ISD::VP_SETCC:
   case ISD::SETCC:
     // Promote the operation by extending the operand.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index 02a989a9699606..b874a4477f5d17 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
 ; CHECK-LABEL: vpreduce_xor_v64i32:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, ta, ma
-; CHECK-NEXT:    li a3, 32
 ; CHECK-NEXT:    vslidedown.vi v24, v0, 4
-; CHECK-NEXT:    mv a2, a1
-; CHECK-NEXT:    bltu a1, a3, .LBB49_2
+; CHECK-NEXT:    addi a2, a1, -32
+; CHECK-NEXT:    sltu a3, a1, a2
+; CHECK-NEXT:    addi a3, a3, -1
+; CHECK-NEXT:    li a4, 32
+; CHECK-NEXT:    and a2, a3, a2
+; CHECK-NEXT:    bltu a1, a4, .LBB49_2
 ; CHECK-NEXT:  # %bb.1:
-; CHECK-NEXT:    li a2, 32
+; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:  .LBB49_2:
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmv.s.x v25, a0
-; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vredxor.vs v25, v8, v25, v0.t
-; CHECK-NEXT:    addi a0, a1, -32
-; CHECK-NEXT:    sltu a1, a1, a0
-; CHECK-NEXT:    addi a1, a1, -1
-; CHECK-NEXT:    and a0, a1, a0
-; CHECK-NEXT:    vsetvli zero, a0, e32, m8, ta, ma
-; CHECK-NEXT:    vmv1r.v v0, v24
-; CHECK-NEXT:    vredxor.vs v25, v16, v25, v0.t
 ; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
+; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT:    vmv1r.v v0, v24
+; CHECK-NEXT:    vredxor.vs v8, v16, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
   %r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
   ret i32 %r
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
index 7bcf37b1af3c8f..95b64cb662a614 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
@@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; CHECK-NEXT:    vmv.s.x v25, a0
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vredmaxu.vs v25, v8, v25, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
 ; CHECK-NEXT:    vmv1r.v v0, v24
-; CHECK-NEXT:    vredmaxu.vs v25, v16, v25, v0.t
-; CHECK-NEXT:    vmv.x.s a0, v25
+; CHECK-NEXT:    vredmaxu.vs v8, v16, v8, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    ret
   %r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, <vscale x 32 x i32> %v, <vscale x 32 x i1> %m, i32 %evl)
   ret i32 %r

>From 9ad0edc888af3c2db46933a000a1995eaba6cbb8 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Mon, 29 Apr 2024 21:45:22 -0700
Subject: [PATCH 2/2] fixup! remove dead code

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9dd40531abb005..398b5fee990b5d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3020,11 +3020,9 @@ void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
   SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
                             Node->getFlags());
 
-  if (ScalarVT.isFloatingPoint())
-    Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
-                      DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
-  else
-    Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+  assert(ScalarVT.isFloatingPoint() && "Only FP promotion is supported");
+  Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+                    DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
 
   Results.push_back(Res);
 }



More information about the llvm-commits mailing list