[llvm] [LLVM][SVE] Improve legalisation of fixed length get.active.lane.mask (PR #90213)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 07:17:33 PDT 2024


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/90213

>From 7b7fb40e1a41e9fa14eb872c36ccd0577b3abd2e Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Thu, 25 Apr 2024 13:12:51 +0100
Subject: [PATCH] [LLVM][SVE] Improve legalisation of fixed length
 get.active.lane.mask operations.

We are effectively performing type and operation legalisation very
early within the code generation flow. This results in worse code
quality because the DAG is not in canonical form, which DAGCombiner
corrects through the introduction of operations that are not legal.

This patch splits and moves the code to where type and operation
legalisation is typically implemented.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 82 ++++++++++---------
 llvm/test/CodeGen/AArch64/active_lane_mask.ll | 36 ++++----
 2 files changed, 63 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8e9782c1930c3c..f7e25d53e94262 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1674,6 +1674,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
 
     setOperationAction(ISD::VSCALE, MVT::i32, Custom);
+
+    for (auto VT : {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1})
+      setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
   }
 
   if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -5686,8 +5689,24 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::get_active_lane_mask: {
     SDValue ID =
         DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), ID,
-                       Op.getOperand(1), Op.getOperand(2));
+
+    EVT VT = Op.getValueType();
+    if (VT.isScalableVector())
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
+                         Op.getOperand(2));
+
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+
+    SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
+                               Op.getOperand(1), Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, dl));
   }
   case Intrinsic::aarch64_neon_uaddlv: {
     EVT OpVT = Op.getOperand(1).getValueType();
@@ -20462,39 +20481,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
-
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
-
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
-
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
-
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
-
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
-    return Res;
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -25568,8 +25554,6 @@ void AArch64TargetLowering::ReplaceNodeResults(
     return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);
-    assert((VT == MVT::i8 || VT == MVT::i16) &&
-           "custom lowering for unexpected type");
 
     Intrinsic::ID IntID =
         static_cast<Intrinsic::ID>(N->getConstantOperandVal(0));
@@ -25577,6 +25561,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
     default:
       return;
     case Intrinsic::aarch64_sve_clasta_n: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
       auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
@@ -25585,6 +25571,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
       return;
     }
     case Intrinsic::aarch64_sve_clastb_n: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
       auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
@@ -25593,6 +25581,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
       return;
     }
     case Intrinsic::aarch64_sve_lasta: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
                            N->getOperand(1), N->getOperand(2));
@@ -25600,12 +25590,30 @@ void AArch64TargetLowering::ReplaceNodeResults(
       return;
     }
     case Intrinsic::aarch64_sve_lastb: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
                            N->getOperand(1), N->getOperand(2));
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::get_active_lane_mask: {
+      if (!VT.isFixedLengthVector())
+        return;
+      if (VT.getVectorElementType() != MVT::i1)
+        return;
+
+      // NOTE: Only trivial type promotion is supported.
+      EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
+      if (NewVT.getVectorNumElements() != VT.getVectorNumElements())
+        return;
+
+      SDLoc DL(N);
+      auto V = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NewVT, N->ops());
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
+      return;
+    }
     }
   }
   case ISD::READ_REGISTER: {
diff --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
index 43122c8c953fc7..bd5d076d1ba82e 100644
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -353,9 +353,9 @@ define <16 x i1> @lane_mask_v16i1_i32(i32 %index, i32 %TC) {
 define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v8i1_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.h, w0, w1
-; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    whilelo p0.b, w0, w1
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %TC)
   ret <8 x i1> %active.lane.mask
@@ -364,9 +364,9 @@ define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
 define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v4i1_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.s, w0, w1
-; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    whilelo p0.h, w0, w1
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %TC)
   ret <4 x i1> %active.lane.mask
@@ -375,9 +375,9 @@ define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
 define <2 x i1> @lane_mask_v2i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v2i1_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.d, w0, w1
-; CHECK-NEXT:    mov z0.d, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    whilelo p0.s, w0, w1
+; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i32(i32 %index, i32 %TC)
   ret <2 x i1> %active.lane.mask
@@ -397,9 +397,9 @@ define <16 x i1> @lane_mask_v16i1_i64(i64 %index, i64 %TC) {
 define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v8i1_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.h, x0, x1
-; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    whilelo p0.b, x0, x1
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 %index, i64 %TC)
   ret <8 x i1> %active.lane.mask
@@ -408,9 +408,9 @@ define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
 define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v4i1_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.s, x0, x1
-; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    whilelo p0.h, x0, x1
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 %index, i64 %TC)
   ret <4 x i1> %active.lane.mask
@@ -419,9 +419,9 @@ define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
 define <2 x i1> @lane_mask_v2i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v2i1_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    whilelo p0.d, x0, x1
-; CHECK-NEXT:    mov z0.d, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    whilelo p0.s, x0, x1
+; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i64(i64 %index, i64 %TC)
   ret <2 x i1> %active.lane.mask



More information about the llvm-commits mailing list