[llvm] b8f765a - [AArch64][SVE] Add support for trunc to <vscale x N x i1>.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 20 13:11:43 PDT 2020


Author: Eli Friedman
Date: 2020-07-20T13:11:02-07:00
New Revision: b8f765a1e17f8d212ab1cd8f630d35adc7495556

URL: https://github.com/llvm/llvm-project/commit/b8f765a1e17f8d212ab1cd8f630d35adc7495556
DIFF: https://github.com/llvm/llvm-project/commit/b8f765a1e17f8d212ab1cd8f630d35adc7495556.diff

LOG: [AArch64][SVE] Add support for trunc to <vscale x N x i1>.

This isn't a natively supported operation, so convert it to a
mask+compare.

In addition to the operation itself, fix up some surrounding stuff to
make the testcase work: we need concat_vectors on i1 vectors, we need
legalization of i1 vector truncates, and we need to fix up all the
relevant uses of getVectorNumElements().

Differential Revision: https://reviews.llvm.org/D83811

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-trunc.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f14b3dba4f31..a026d3960026 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11372,9 +11372,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
       // Stop if more than one members are non-undef.
       if (NumDefs > 1)
         break;
+
       VTs.push_back(EVT::getVectorVT(*DAG.getContext(),
                                      VT.getVectorElementType(),
-                                     X.getValueType().getVectorNumElements()));
+                                     X.getValueType().getVectorElementCount()));
     }
 
     if (NumDefs == 0)
@@ -18795,6 +18796,11 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
 static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   EVT OpVT = N->getOperand(0).getValueType();
+
+  // We currently can't generate an appropriate shuffle for a scalable vector.
+  if (VT.isScalableVector())
+    return SDValue();
+
   int NumElts = VT.getVectorNumElements();
   int NumOpElts = OpVT.getVectorNumElements();
 
@@ -19055,11 +19061,14 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
     return V;
 
   // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR
-  // nodes often generate nop CONCAT_VECTOR nodes.
-  // Scan the CONCAT_VECTOR operands and look for a CONCAT operations that
-  // place the incoming vectors at the exact same location.
+  // nodes often generate nop CONCAT_VECTOR nodes. Scan the CONCAT_VECTOR
+  // operands and look for CONCAT operations that place the incoming vectors
+  // at the exact same location.
+  //
+  // For scalable vectors, EXTRACT_SUBVECTOR indexes are implicitly scaled.
   SDValue SingleSource = SDValue();
-  unsigned PartNumElem = N->getOperand(0).getValueType().getVectorNumElements();
+  unsigned PartNumElem =
+      N->getOperand(0).getValueType().getVectorMinNumElements();
 
   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
     SDValue Op = N->getOperand(i);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 414ba25ffd5f..b1ec3050e201 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2151,7 +2151,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
   EVT InVT = Lo.getValueType();
 
   EVT OutVT = EVT::getVectorVT(*DAG.getContext(), ResVT.getVectorElementType(),
-                               InVT.getVectorNumElements());
+                               InVT.getVectorElementCount());
 
   if (N->isStrictFPOpcode()) {
     Lo = DAG.getNode(N->getOpcode(), dl, { OutVT, MVT::Other }, 
@@ -2559,13 +2559,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_TruncateHelper(SDNode *N) {
   SDValue InVec = N->getOperand(OpNo);
   EVT InVT = InVec->getValueType(0);
   EVT OutVT = N->getValueType(0);
-  unsigned NumElements = OutVT.getVectorNumElements();
+  ElementCount NumElements = OutVT.getVectorElementCount();
   bool IsFloat = OutVT.isFloatingPoint();
 
-  // Widening should have already made sure this is a power-two vector
-  // if we're trying to split it at all. assert() that's true, just in case.
-  assert(!(NumElements & 1) && "Splitting vector, but not in half!");
-
   unsigned InElementSize = InVT.getScalarSizeInBits();
   unsigned OutElementSize = OutVT.getScalarSizeInBits();
 
@@ -2595,6 +2591,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_TruncateHelper(SDNode *N) {
   GetSplitVector(InVec, InLoVec, InHiVec);
 
   // Truncate them to 1/2 the element size.
+  //
+  // This assumes the number of elements is a power of two; any vector that
+  // isn't should be widened, not split.
   EVT HalfElementVT = IsFloat ?
     EVT::getFloatingPointVT(InElementSize/2) :
     EVT::getIntegerVT(*DAG.getContext(), InElementSize/2);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c13ab9412b5b..2c49462d277f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -931,8 +931,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::SHL, VT, Custom);
         setOperationAction(ISD::SRL, VT, Custom);
         setOperationAction(ISD::SRA, VT, Custom);
-        if (VT.getScalarType() == MVT::i1)
+        if (VT.getScalarType() == MVT::i1) {
           setOperationAction(ISD::SETCC, VT, Custom);
+          setOperationAction(ISD::TRUNCATE, VT, Custom);
+          setOperationAction(ISD::CONCAT_VECTORS, VT, Legal);
+        }
       }
     }
 
@@ -8874,6 +8877,16 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
                                              SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
 
+  if (VT.getScalarType() == MVT::i1) {
+    // Lower i1 truncate to `(x & 1) != 0`.
+    SDLoc dl(Op);
+    EVT OpVT = Op.getOperand(0).getValueType();
+    SDValue Zero = DAG.getConstant(0, dl, OpVT);
+    SDValue One = DAG.getConstant(1, dl, OpVT);
+    SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One);
+    return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE);
+  }
+
   if (!VT.isVector() || VT.isScalableVector())
     return Op;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1d7b774f2ee4..dc501a9536b9 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1135,6 +1135,14 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   def : Pat<(nxv8i1 (extract_subvector (nxv16i1 PPR:$Ps), (i64 8))),
             (ZIP2_PPP_B PPR:$Ps, (PFALSE))>;
 
+  // Concatenate two predicates.
+  def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)),
+            (UZP1_PPP_S $p1, $p2)>;
+  def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)),
+            (UZP1_PPP_H $p1, $p2)>;
+  def : Pat<(nxv16i1 (concat_vectors nxv8i1:$p1, nxv8i1:$p2)),
+            (UZP1_PPP_B $p1, $p2)>;
+
   defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>;
   defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>;
   defm CMPGE_PPzZZ : sve_int_cmp_0<0b100, "cmpge", SETGE, SETLE>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-trunc.ll b/llvm/test/CodeGen/AArch64/sve-trunc.ll
index 876003a3962c..3743301cfa9b 100644
--- a/llvm/test/CodeGen/AArch64/sve-trunc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-trunc.ll
@@ -59,3 +59,123 @@ entry:
   %out = trunc <vscale x 2 x i64> %in to <vscale x 2 x i32>
   ret <vscale x 2 x i32> %out
 }
+
+; Truncating to i1 requires converting it to a cmp
+
+define <vscale x 2 x i1> @trunc_i64toi1(<vscale x 2 x i64> %in) {
+; CHECK-LABEL: trunc_i64toi1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z0.d, z0.d, #0x1
+; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 2 x i64> %in to <vscale x 2 x i1>
+  ret <vscale x 2 x i1> %out
+}
+
+define <vscale x 4 x i1> @trunc_i64toi1_split(<vscale x 4 x i64> %in) {
+; CHECK-LABEL: trunc_i64toi1_split:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0x1
+; CHECK-NEXT:    and z0.d, z0.d, #0x1
+; CHECK-NEXT:    cmpne p1.d, p0/z, z1.d, #0
+; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 4 x i64> %in to <vscale x 4 x i1>
+  ret <vscale x 4 x i1> %out
+}
+
+define <vscale x 8 x i1> @trunc_i64toi1_split2(<vscale x 8 x i64> %in) {
+; CHECK-LABEL: trunc_i64toi1_split2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z3.d, z3.d, #0x1
+; CHECK-NEXT:    and z2.d, z2.d, #0x1
+; CHECK-NEXT:    and z1.d, z1.d, #0x1
+; CHECK-NEXT:    and z0.d, z0.d, #0x1
+; CHECK-NEXT:    cmpne p1.d, p0/z, z3.d, #0
+; CHECK-NEXT:    cmpne p2.d, p0/z, z2.d, #0
+; CHECK-NEXT:    uzp1 p1.s, p2.s, p1.s
+; CHECK-NEXT:    cmpne p2.d, p0/z, z1.d, #0
+; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p2.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p1.h
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 8 x i64> %in to <vscale x 8 x i1>
+  ret <vscale x 8 x i1> %out
+}
+
+define <vscale x 16 x i1> @trunc_i64toi1_split3(<vscale x 16 x i64> %in) {
+; CHECK-LABEL: trunc_i64toi1_split3:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z7.d, z7.d, #0x1
+; CHECK-NEXT:    and z6.d, z6.d, #0x1
+; CHECK-NEXT:    and z5.d, z5.d, #0x1
+; CHECK-NEXT:    and z4.d, z4.d, #0x1
+; CHECK-NEXT:    and z3.d, z3.d, #0x1
+; CHECK-NEXT:    and z2.d, z2.d, #0x1
+; CHECK-NEXT:    cmpne p1.d, p0/z, z7.d, #0
+; CHECK-NEXT:    cmpne p2.d, p0/z, z6.d, #0
+; CHECK-NEXT:    cmpne p3.d, p0/z, z5.d, #0
+; CHECK-NEXT:    cmpne p4.d, p0/z, z4.d, #0
+; CHECK-NEXT:    and z1.d, z1.d, #0x1
+; CHECK-NEXT:    and z0.d, z0.d, #0x1
+; CHECK-NEXT:    uzp1 p1.s, p2.s, p1.s
+; CHECK-NEXT:    cmpne p2.d, p0/z, z3.d, #0
+; CHECK-NEXT:    uzp1 p3.s, p4.s, p3.s
+; CHECK-NEXT:    cmpne p4.d, p0/z, z2.d, #0
+; CHECK-NEXT:    uzp1 p2.s, p4.s, p2.s
+; CHECK-NEXT:    cmpne p4.d, p0/z, z1.d, #0
+; CHECK-NEXT:    cmpne p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p4.s
+; CHECK-NEXT:    uzp1 p1.h, p3.h, p1.h
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 16 x i64> %in to <vscale x 16 x i1>
+  ret <vscale x 16 x i1> %out
+}
+
+
+define <vscale x 4 x i1> @trunc_i32toi1(<vscale x 4 x i32> %in) {
+; CHECK-LABEL: trunc_i32toi1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    and z0.s, z0.s, #0x1
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 4 x i32> %in to <vscale x 4 x i1>
+  ret <vscale x 4 x i1> %out
+}
+
+define <vscale x 8 x i1> @trunc_i16toi1(<vscale x 8 x i16> %in) {
+; CHECK-LABEL: trunc_i16toi1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    and z0.h, z0.h, #0x1
+; CHECK-NEXT:    cmpne p0.h, p0/z, z0.h, #0
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 8 x i16> %in to <vscale x 8 x i1>
+  ret <vscale x 8 x i1> %out
+}
+
+define <vscale x 16 x i1> @trunc_i8toi1(<vscale x 16 x i8> %in) {
+; CHECK-LABEL: trunc_i8toi1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    and z0.b, z0.b, #0x1
+; CHECK-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; CHECK-NEXT:    ret
+entry:
+  %out = trunc <vscale x 16 x i8> %in to <vscale x 16 x i1>
+  ret <vscale x 16 x i1> %out
+}


        


More information about the llvm-commits mailing list