[llvm] 6795331 - [SelectionDAG] Teach SelectionDAG::FoldConstantArithmetic to handle SPLAT_VECTOR

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 7 10:12:51 PDT 2021


Author: Craig Topper
Date: 2021-04-07T10:03:33-07:00
New Revision: 67953311e2e370a9fcf77595d66d39c505565382

URL: https://github.com/llvm/llvm-project/commit/67953311e2e370a9fcf77595d66d39c505565382
DIFF: https://github.com/llvm/llvm-project/commit/67953311e2e370a9fcf77595d66d39c505565382.diff

LOG: [SelectionDAG] Teach SelectionDAG::FoldConstantArithmetic to handle SPLAT_VECTOR

This allows FoldConstantArithmetic to handle SPLAT_VECTOR in
addition to BUILD_VECTOR, which lets it support scalable
vectors. I'm also allowing fixed-length SPLAT_VECTOR, which is
used by some targets, but I'm not familiar enough with those
targets to write tests for them.

I had to block this function from running on CONCAT_VECTORS to
avoid calling getNode for a CONCAT_VECTORS of two scalars.
This can happen because the two-operand getNode calls this
function for any opcode. Previously we were protected because
CONCAT_VECTORS of BUILD_VECTORs are folded to a larger BUILD_VECTOR
before that call. But it's not always possible to fold a CONCAT_VECTORS
of SPLAT_VECTORs, and we don't even try.

This fixes PR49781, where DAG combine assumed constant folding
was possible but FoldConstantArithmetic could not perform it.

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D99682

Added: 
    llvm/test/CodeGen/AArch64/pr49781.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv32.ll
    llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv64.ll
    llvm/test/CodeGen/RISCV/rvv/vmulh-sdnode-rv32.ll
    llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv32.ll
    llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv64.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index cfe0340198a8..89371aa5cab9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5056,7 +5056,10 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
   // If the opcode is a target-specific ISD node, there's nothing we can
   // do here and the operand rules may not line up with the below, so
   // bail early.
-  if (Opcode >= ISD::BUILTIN_OP_END)
+  // We can't create a scalar CONCAT_VECTORS so skip it. It will break
+  // for concats involving SPLAT_VECTOR. Concats of BUILD_VECTORS are handled by
+  // foldCONCAT_VECTORS in getNode before this is called.
+  if (Opcode >= ISD::BUILTIN_OP_END || Opcode == ISD::CONCAT_VECTORS)
     return SDValue();
 
   // For now, the array Ops should only contain two values.
@@ -5096,27 +5099,20 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
     if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N2))
       return FoldSymbolOffset(Opcode, VT, GA, N1);
 
-  // TODO: All the folds below are performed lane-by-lane and assume a fixed
-  // vector width, however we should be able to do constant folds involving
-  // splat vector nodes too.
-  if (VT.isScalableVector())
-    return SDValue();
-
   // For fixed width vectors, extract each constant element and fold them
   // individually. Either input may be an undef value.
-  auto *BV1 = dyn_cast<BuildVectorSDNode>(N1);
-  if (!BV1 && !N1->isUndef())
+  bool IsBVOrSV1 = N1->getOpcode() == ISD::BUILD_VECTOR ||
+                   N1->getOpcode() == ISD::SPLAT_VECTOR;
+  if (!IsBVOrSV1 && !N1->isUndef())
     return SDValue();
-  auto *BV2 = dyn_cast<BuildVectorSDNode>(N2);
-  if (!BV2 && !N2->isUndef())
+  bool IsBVOrSV2 = N2->getOpcode() == ISD::BUILD_VECTOR ||
+                   N2->getOpcode() == ISD::SPLAT_VECTOR;
+  if (!IsBVOrSV2 && !N2->isUndef())
     return SDValue();
   // If both operands are undef, that's handled the same way as scalars.
-  if (!BV1 && !BV2)
+  if (!IsBVOrSV1 && !IsBVOrSV2)
     return SDValue();
 
-  assert((!BV1 || !BV2 || BV1->getNumOperands() == BV2->getNumOperands()) &&
-         "Vector binop with 
diff erent number of elements in operands?");
-
   EVT SVT = VT.getScalarType();
   EVT LegalSVT = SVT;
   if (NewNodesMustHaveLegalTypes && LegalSVT.isInteger()) {
@@ -5124,19 +5120,46 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
     if (LegalSVT.bitsLT(SVT))
       return SDValue();
   }
+
   SmallVector<SDValue, 4> Outputs;
-  unsigned NumOps = BV1 ? BV1->getNumOperands() : BV2->getNumOperands();
+  unsigned NumOps = 0;
+  if (IsBVOrSV1)
+    NumOps = std::max(NumOps, N1->getNumOperands());
+  if (IsBVOrSV2)
+    NumOps = std::max(NumOps, N2->getNumOperands());
+  assert(NumOps != 0 && "Expected non-zero operands");
+  // Scalable vectors should only be SPLAT_VECTOR or UNDEF here. We only need
+  // one iteration for that.
+  assert((!VT.isScalableVector() || NumOps == 1) &&
+         "Scalar vector should only have one scalar");
+
   for (unsigned I = 0; I != NumOps; ++I) {
-    SDValue V1 = BV1 ? BV1->getOperand(I) : getUNDEF(SVT);
-    SDValue V2 = BV2 ? BV2->getOperand(I) : getUNDEF(SVT);
+    // We can have a fixed length SPLAT_VECTOR and a BUILD_VECTOR so we need
+    // to use operand 0 of the SPLAT_VECTOR for each fixed element.
+    SDValue V1;
+    if (N1->getOpcode() == ISD::BUILD_VECTOR)
+      V1 = N1->getOperand(I);
+    else if (N1->getOpcode() == ISD::SPLAT_VECTOR)
+      V1 = N1->getOperand(0);
+    else
+      V1 = getUNDEF(SVT);
+
+    SDValue V2;
+    if (N2->getOpcode() == ISD::BUILD_VECTOR)
+      V2 = N2->getOperand(I);
+    else if (N2->getOpcode() == ISD::SPLAT_VECTOR)
+      V2 = N2->getOperand(0);
+    else
+      V2 = getUNDEF(SVT);
+
     if (SVT.isInteger()) {
-      if (V1->getValueType(0).bitsGT(SVT))
+      if (V1.getValueType().bitsGT(SVT))
         V1 = getNode(ISD::TRUNCATE, DL, SVT, V1);
-      if (V2->getValueType(0).bitsGT(SVT))
+      if (V2.getValueType().bitsGT(SVT))
         V2 = getNode(ISD::TRUNCATE, DL, SVT, V2);
     }
 
-    if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT)
+    if (V1.getValueType() != SVT || V2.getValueType() != SVT)
       return SDValue();
 
     // Fold one vector element.
@@ -5151,11 +5174,21 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
     Outputs.push_back(ScalarResult);
   }
 
-  assert(VT.getVectorNumElements() == Outputs.size() &&
-         "Vector size mismatch!");
+  if (N1->getOpcode() == ISD::BUILD_VECTOR ||
+      N2->getOpcode() == ISD::BUILD_VECTOR) {
+    assert(VT.getVectorNumElements() == Outputs.size() &&
+           "Vector size mismatch!");
+
+    // Build a big vector out of the scalar elements we generated.
+    return getBuildVector(VT, SDLoc(), Outputs);
+  }
+
+  assert((N1->getOpcode() == ISD::SPLAT_VECTOR ||
+          N2->getOpcode() == ISD::SPLAT_VECTOR) &&
+         "One operand should be a splat vector");
 
-  // Build a big vector out of the scalar elements we generated.
-  return getBuildVector(VT, SDLoc(), Outputs);
+  assert(Outputs.size() == 1 && "Vector size mismatch!");
+  return getSplatVector(VT, SDLoc(), Outputs[0]);
 }
 
 // TODO: Merge with FoldConstantArithmetic

diff  --git a/llvm/test/CodeGen/AArch64/pr49781.ll b/llvm/test/CodeGen/AArch64/pr49781.ll
new file mode 100644
index 000000000000..066feda599d2
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/pr49781.ll
@@ -0,0 +1,13 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 -mattr=+sve | FileCheck %s
+
+define <vscale x 2 x i64> @foo(<vscale x 2 x i64> %a) {
+; CHECK-LABEL: foo:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub z0.d, z0.d, #2 // =0x2
+; CHECK-NEXT:    ret
+ %idx = shufflevector <vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 1, i32 0), <vscale x 2 x i64> zeroinitializer, <vscale x 2 x i32> zeroinitializer
+ %b = sub <vscale x 2 x i64> %a, %idx
+ %c = sub <vscale x 2 x i64> %b, %idx
+ ret <vscale x 2 x i64> %c
+}

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv32.ll b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv32.ll
index abd5ccedfaa6..84f11909bd4f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv32.ll
@@ -38,13 +38,11 @@ define <vscale x 1 x i8> @vadd_vx_nxv1i8_1(<vscale x 1 x i8> %va) {
 }
 
 ; Test constant adds to see if we can optimize them away for scalable vectors.
-; FIXME: We can't.
 define <vscale x 1 x i8> @vadd_ii_nxv1i8_1() {
 ; CHECK-LABEL: vadd_ii_nxv1i8_1:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetvli a0, zero, e8,mf8,ta,mu
-; CHECK-NEXT:    vmv.v.i v25, 2
-; CHECK-NEXT:    vadd.vi v8, v25, 3
+; CHECK-NEXT:    vmv.v.i v8, 5
 ; CHECK-NEXT:    ret
   %heada = insertelement <vscale x 1 x i8> undef, i8 2, i32 0
   %splata = shufflevector <vscale x 1 x i8> %heada, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv64.ll b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv64.ll
index a54103da10fd..cc31c771210e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vadd-sdnode-rv64.ll
@@ -37,6 +37,21 @@ define <vscale x 1 x i8> @vadd_vx_nxv1i8_1(<vscale x 1 x i8> %va) {
   ret <vscale x 1 x i8> %vc
 }
 
+; Test constant adds to see if we can optimize them away for scalable vectors.
+define <vscale x 1 x i8> @vadd_ii_nxv1i8_1() {
+; CHECK-LABEL: vadd_ii_nxv1i8_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e8,mf8,ta,mu
+; CHECK-NEXT:    vmv.v.i v8, 5
+; CHECK-NEXT:    ret
+  %heada = insertelement <vscale x 1 x i8> undef, i8 2, i32 0
+  %splata = shufflevector <vscale x 1 x i8> %heada, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %headb = insertelement <vscale x 1 x i8> undef, i8 3, i32 0
+  %splatb = shufflevector <vscale x 1 x i8> %headb, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %vc = add <vscale x 1 x i8> %splata, %splatb
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vadd_vx_nxv2i8(<vscale x 2 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vadd_vx_nxv2i8:
 ; CHECK:       # %bb.0:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vmulh-sdnode-rv32.ll b/llvm/test/CodeGen/RISCV/rvv/vmulh-sdnode-rv32.ll
index 2798b67fa355..fbbae63de1c7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vmulh-sdnode-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmulh-sdnode-rv32.ll
@@ -11,13 +11,9 @@ define <vscale x 4 x i1> @srem_eq_fold_nxv4i8(<vscale x 4 x i8> %va) {
 ; CHECK-NEXT:    vmul.vx v25, v8, a0
 ; CHECK-NEXT:    addi a0, zero, 42
 ; CHECK-NEXT:    vadd.vx v25, v25, a0
-; CHECK-NEXT:    vmv.v.i v26, 1
-; CHECK-NEXT:    vrsub.vi v27, v26, 0
-; CHECK-NEXT:    vand.vi v27, v27, 7
-; CHECK-NEXT:    vsll.vv v27, v25, v27
-; CHECK-NEXT:    vand.vi v26, v26, 7
-; CHECK-NEXT:    vsrl.vv v25, v25, v26
-; CHECK-NEXT:    vor.vv v25, v25, v27
+; CHECK-NEXT:    vsll.vi v26, v25, 7
+; CHECK-NEXT:    vsrl.vi v25, v25, 1
+; CHECK-NEXT:    vor.vv v25, v25, v26
 ; CHECK-NEXT:    vmsleu.vx v0, v25, a0
 ; CHECK-NEXT:    ret
   %head_six = insertelement <vscale x 4 x i8> undef, i8 6, i32 0

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv32.ll b/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv32.ll
index e2317ebf7af1..fb5ddf05f357 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv32.ll
@@ -36,6 +36,21 @@ define <vscale x 1 x i8> @vsub_vx_nxv1i8_0(<vscale x 1 x i8> %va) {
   ret <vscale x 1 x i8> %vc
 }
 
+; Test constant subs to see if we can optimize them away for scalable vectors.
+define <vscale x 1 x i8> @vsub_ii_nxv1i8_1() {
+; CHECK-LABEL: vsub_ii_nxv1i8_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e8,mf8,ta,mu
+; CHECK-NEXT:    vmv.v.i v8, -1
+; CHECK-NEXT:    ret
+  %heada = insertelement <vscale x 1 x i8> undef, i8 2, i32 0
+  %splata = shufflevector <vscale x 1 x i8> %heada, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %headb = insertelement <vscale x 1 x i8> undef, i8 3, i32 0
+  %splatb = shufflevector <vscale x 1 x i8> %headb, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %vc = sub <vscale x 1 x i8> %splata, %splatb
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vsub_vv_nxv2i8(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
 ; CHECK-LABEL: vsub_vv_nxv2i8:
 ; CHECK:       # %bb.0:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv64.ll b/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv64.ll
index 090bae1909da..db8702f1360e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsub-sdnode-rv64.ll
@@ -36,6 +36,21 @@ define <vscale x 1 x i8> @vsub_vx_nxv1i8_0(<vscale x 1 x i8> %va) {
   ret <vscale x 1 x i8> %vc
 }
 
+; Test constant subs to see if we can optimize them away for scalable vectors.
+define <vscale x 1 x i8> @vsub_ii_nxv1i8_1() {
+; CHECK-LABEL: vsub_ii_nxv1i8_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e8,mf8,ta,mu
+; CHECK-NEXT:    vmv.v.i v8, -1
+; CHECK-NEXT:    ret
+  %heada = insertelement <vscale x 1 x i8> undef, i8 2, i32 0
+  %splata = shufflevector <vscale x 1 x i8> %heada, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %headb = insertelement <vscale x 1 x i8> undef, i8 3, i32 0
+  %splatb = shufflevector <vscale x 1 x i8> %headb, <vscale x 1 x i8> undef, <vscale x 1 x i32> zeroinitializer
+  %vc = sub <vscale x 1 x i8> %splata, %splatb
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vsub_vv_nxv2i8(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
 ; CHECK-LABEL: vsub_vv_nxv2i8:
 ; CHECK:       # %bb.0:


        


More information about the llvm-commits mailing list