[llvm] c9da81d - [AArch64][SVE] Implement missing lowering for extract_subvector for predicates.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 27 03:01:20 PST 2022


Author: Sander de Smalen
Date: 2022-01-27T11:01:11Z
New Revision: c9da81d99760eebbe5bb4ab1188962a7edb5e780

URL: https://github.com/llvm/llvm-project/commit/c9da81d99760eebbe5bb4ab1188962a7edb5e780
DIFF: https://github.com/llvm/llvm-project/commit/c9da81d99760eebbe5bb4ab1188962a7edb5e780.diff

LOG: [AArch64][SVE] Implement missing lowering for extract_subvector for predicates.

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D118057

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-insert-vector.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7602fd829a3..676ee1b18914 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1248,6 +1248,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SELECT_CC, VT, Expand);
       setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
       setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+      setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
 
       // There are no legal MVT::nxv16f## based types.
       if (VT != MVT::nxv16i1) {
@@ -11038,6 +11039,28 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
     if (!isTypeLegal(VT))
       return SDValue();
 
+    // Break down insert_subvector into simpler parts.
+    if (VT.getVectorElementType() == MVT::i1) {
+      unsigned NumElts = VT.getVectorMinNumElements();
+      EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
+
+      SDValue Lo, Hi;
+      Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+                       DAG.getVectorIdxConstant(0, DL));
+      Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+                       DAG.getVectorIdxConstant(NumElts / 2, DL));
+      if (Idx < (NumElts / 2)) {
+        SDValue NewLo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
+                                    DAG.getVectorIdxConstant(Idx, DL));
+        return DAG.getNode(AArch64ISD::UZP1, DL, VT, NewLo, Hi);
+      } else {
+        SDValue NewHi =
+            DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
+                        DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
+        return DAG.getNode(AArch64ISD::UZP1, DL, VT, Lo, NewHi);
+      }
+    }
+
     // Ensure the subvector is half the size of the main vector.
     if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
       return SDValue();
@@ -12961,7 +12984,7 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
-  return (Index == 0 || Index == ResVT.getVectorNumElements());
+  return (Index == 0 || Index == ResVT.getVectorMinNumElements());
 }
 
 /// Turn vector tests of the signbit in the form of:
@@ -14321,6 +14344,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
 static SDValue
 performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
+  SDLoc DL(N);
   SDValue Vec = N->getOperand(0);
   SDValue SubVec = N->getOperand(1);
   uint64_t IdxVal = N->getConstantOperandVal(2);
@@ -14346,7 +14370,6 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   // Fold insert_subvector -> concat_vectors
   // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
   // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
-  SDLoc DL(N);
   SDValue Lo, Hi;
   if (IdxVal == 0) {
     Lo = SubVec;

diff  --git a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
index 58034406be46..594b3e0b2f8b 100644
--- a/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-insert-vector.ll
@@ -501,6 +501,80 @@ define <vscale x 8 x bfloat> @insert_nxv8bf16_v8bf16(<vscale x 8 x bfloat> %sv0,
   ret <vscale x 8 x bfloat> %v0
 }
 
+; Test predicate inserts of half size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_0(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_8(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv8i1_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 8)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate inserts of less than half the size.
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_0(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p2.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p1.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p1.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_12(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_12:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpkhi p2.h, p0.b
+; CHECK-NEXT:    punpklo p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p2.b
+; CHECK-NEXT:    uzp1 p1.h, p2.h, p1.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 12)
+  ret <vscale x 16 x i1> %v0
+}
+
+; Test predicate insert into zero/poison
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_zero(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_zero:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p1.h, p1.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p2.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> zeroinitializer, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_poison(<vscale x 4 x i1> %sv) {
+; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_poison:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
+; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
+; CHECK-NEXT:    ret
+  %v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> poison, <vscale x 4 x i1> %sv, i64 0)
+  ret <vscale x 16 x i1> %v0
+}
+
+
 declare <vscale x 3 x i32> @llvm.experimental.vector.insert.nxv3i32.nxv2i32(<vscale x 3 x i32>, <vscale x 2 x i32>, i64)
 declare <vscale x 3 x float> @llvm.experimental.vector.insert.nxv3f32.nxv2f32(<vscale x 3 x float>, <vscale x 2 x float>, i64)
 declare <vscale x 6 x i32> @llvm.experimental.vector.insert.nxv6i32.nxv2i32(<vscale x 6 x i32>, <vscale x 2 x i32>, i64)
@@ -511,3 +585,6 @@ declare <vscale x 8 x bfloat> @llvm.experimental.vector.insert.nxv8bf16.v8bf16(<
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i64)
 declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.v4bf16(<vscale x 4 x bfloat>, <4 x bfloat>, i64)
 declare <vscale x 2 x bfloat> @llvm.experimental.vector.insert.nxv2bf16.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i64)
+
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1>, <vscale x 4 x i1>, i64)
+declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i64)


        


More information about the llvm-commits mailing list