[llvm] e5edc1b - [AArch64][SVE] Ensure PTEST operands have type nxv16i1

Rosie Sumpter via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 12 01:34:26 PDT 2022


Author: Rosie Sumpter
Date: 2022-07-12T09:27:59+01:00
New Revision: e5edc1b5eecfb8abc4e6d4d385da7ed0b456579c

URL: https://github.com/llvm/llvm-project/commit/e5edc1b5eecfb8abc4e6d4d385da7ed0b456579c
DIFF: https://github.com/llvm/llvm-project/commit/e5edc1b5eecfb8abc4e6d4d385da7ed0b456579c.diff

LOG: [AArch64][SVE] Ensure PTEST operands have type nxv16i1

Currently, any legal predicate type will be pattern-matched when
creating a PTEST instruction. This could be a problem in the future
since PTEST always uses the .B specifier for the operand, but it is
not always guaranteed that the extra lanes of unpacked types
(e.g. nxv4i1) are zero. This patch ensures that the operands of PTEST
have type nxv16i1, with the undef lanes set to zero.

Differential Revision: https://reviews.llvm.org/D129282
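
For illustration only (this example is not part of the commit and uses no
LLVM API): a small standalone C++ model of why the .B specifier is unsafe
with unpacked predicates. One 128-bit predicate granule is modelled as 16
bits, one per .b lane; an nxv4i1 value defines only bits 0, 4, 8 and 12,
and the bits in between are undefined. The lane positions and bit values
below are made up for the example.

    #include <bitset>
    #include <cstdio>

    int main() {
      // Defined .s lane positions within the granule: bits 0, 4, 8, 12.
      std::bitset<16> s_lanes;
      for (int i = 0; i < 16; i += 4)
        s_lanes[i] = true;

      std::bitset<16> pg = s_lanes;  // governing predicate: all .s lanes active
      std::bitset<16> op;            // compare result: all .s lanes inactive

      // The in-between bits are undefined; assume bit 1 happens to be set
      // in both values.
      pg[1] = true;
      op[1] = true;

      bool ptest_b_any = (pg & op).any();            // what a .b PTEST reports
      bool correct_any = (pg & op & s_lanes).any();  // the intended .s answer

      std::printf("ptest over all .b lanes: %d, intended .s result: %d\n",
                  ptest_b_any, correct_any);         // prints 1 vs 0
    }

The patch avoids this by converting both PTEST operands to nxv16i1 first,
either proving the inactive lanes are already zero (isZeroingInactiveLanes)
or explicitly zeroing them via the predicate bitcast.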

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/SVEInstrFormats.td
    llvm/test/CodeGen/AArch64/sve-setcc.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8108b77978e5b..447ad10ddf228 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -237,6 +237,39 @@ static bool isMergePassthruOpcode(unsigned Opc) {
   }
 }
 
+// Returns true if inactive lanes are known to be zeroed by construction.
+static bool isZeroingInactiveLanes(SDValue Op) {
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return true;
+    return false;
+  case AArch64ISD::PTRUE:
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return true;
+  case ISD::INTRINSIC_WO_CHAIN:
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      return false;
+    case Intrinsic::aarch64_sve_ptrue:
+    case Intrinsic::aarch64_sve_pnext:
+    case Intrinsic::aarch64_sve_cmpeq_wide:
+    case Intrinsic::aarch64_sve_cmpne_wide:
+    case Intrinsic::aarch64_sve_cmpge_wide:
+    case Intrinsic::aarch64_sve_cmpgt_wide:
+    case Intrinsic::aarch64_sve_cmplt_wide:
+    case Intrinsic::aarch64_sve_cmple_wide:
+    case Intrinsic::aarch64_sve_cmphs_wide:
+    case Intrinsic::aarch64_sve_cmphi_wide:
+    case Intrinsic::aarch64_sve_cmplo_wide:
+    case Intrinsic::aarch64_sve_cmpls_wide:
+      return true;
+    }
+  }
+}
+
 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                                              const AArch64Subtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -4368,16 +4401,18 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
-                                                      SelectionDAG &DAG) const {
+// Returns a safe bitcast between two scalable vector predicates, where
+// any newly created lanes from a widening bitcast are defined as zero.
+static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
 
   assert(InVT.getVectorElementType() == MVT::i1 &&
          VT.getVectorElementType() == MVT::i1 &&
          "Expected a predicate-to-predicate bitcast");
-  assert(VT.isScalableVector() && isTypeLegal(VT) &&
-         InVT.isScalableVector() && isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
+         InVT.isScalableVector() &&
+         DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
@@ -4396,33 +4431,8 @@ SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
 
   // Check if the other lanes are already known to be zeroed by
   // construction.
-  switch (Op.getOpcode()) {
-  default:
-    // We guarantee i1 splat_vectors to zero the other lanes by
-    // implementing it with ptrue and possibly a punpklo for nxv1i1.
-    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
-      return Reinterpret;
-    break;
-  case AArch64ISD::SETCC_MERGE_ZERO:
+  if (isZeroingInactiveLanes(Op))
     return Reinterpret;
-  case ISD::INTRINSIC_WO_CHAIN:
-    switch (Op.getConstantOperandVal(0)) {
-    default:
-      break;
-    case Intrinsic::aarch64_sve_ptrue:
-    case Intrinsic::aarch64_sve_cmpeq_wide:
-    case Intrinsic::aarch64_sve_cmpne_wide:
-    case Intrinsic::aarch64_sve_cmpge_wide:
-    case Intrinsic::aarch64_sve_cmpgt_wide:
-    case Intrinsic::aarch64_sve_cmplt_wide:
-    case Intrinsic::aarch64_sve_cmple_wide:
-    case Intrinsic::aarch64_sve_cmphs_wide:
-    case Intrinsic::aarch64_sve_cmphi_wide:
-    case Intrinsic::aarch64_sve_cmplo_wide:
-    case Intrinsic::aarch64_sve_cmpls_wide:
-      return Reinterpret;
-    }
-  }
 
   // Zero the newly introduced lanes.
   SDValue Mask = DAG.getConstant(1, DL, InVT);
@@ -16164,12 +16174,24 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
   assert(Op.getValueType().isScalableVector() &&
          TLI.isTypeLegal(Op.getValueType()) &&
          "Expected legal scalable vector type!");
+  assert(Op.getValueType() == Pg.getValueType() &&
+         "Expected same type for PTEST operands");
 
   // Ensure target specific opcodes are using legal type.
   EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
   SDValue TVal = DAG.getConstant(1, DL, OutVT);
   SDValue FVal = DAG.getConstant(0, DL, OutVT);
 
+  // Ensure operands have type nxv16i1.
+  if (Op.getValueType() != MVT::nxv16i1) {
+    if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
+        isZeroingInactiveLanes(Op))
+      Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
+    else
+      Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
+  }
+
   // Set condition code (CC) flags.
   SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op);
 

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 48a559b4352ac..e02b5e56fd2e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1154,10 +1154,6 @@ class AArch64TargetLowering : public TargetLowering {
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns a safe bitcast between two scalable vector predicates, where
-  // any newly created lanes from a widening bitcast are defined as zero.
-  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
-
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 58ef4b3e09bc4..c66f9cfd9c226 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -778,7 +778,7 @@ let Predicates = [HasSVEorSME] in {
   defm BRKB_PPmP  : sve_int_break_m<0b101, "brkb",  int_aarch64_sve_brkb>;
   defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
 
-  def PTEST_PP : sve_int_ptest<0b010000, "ptest">;
+  def PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest>;
   defm PFALSE  : sve_int_pfalse<0b000000, "pfalse">;
   defm PFIRST  : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
   defm PNEXT   : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
@@ -2131,17 +2131,6 @@ let Predicates = [HasSVEorSME] in {
     def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>;
   }
 
-  def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-  def : Pat<(AArch64ptest (nxv1i1 PPR:$pg), (nxv1i1 PPR:$src)),
-            (PTEST_PP PPR:$pg, PPR:$src)>;
-
   let AddedComplexity = 1 in {
   class LD1RPat<ValueType vt, SDPatternOperator operator,
                 Instruction load, Instruction ptrue, ValueType index_vt, ComplexPattern CP, Operand immtype> :

diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 80e38d002b0fe..7cdd4c4af95ec 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -650,11 +650,11 @@ multiclass sve_int_pfalse<bits<6> opc, string asm> {
   def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
 }
 
-class sve_int_ptest<bits<6> opc, string asm>
+class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
 : I<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
   asm, "\t$Pg, $Pn",
   "",
-  []>, Sched<[]> {
+  [(op (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, Sched<[]> {
   bits<4> Pg;
   bits<4> Pn;
   let Inst{31-24} = 0b00100101;

diff --git a/llvm/test/CodeGen/AArch64/sve-setcc.ll b/llvm/test/CodeGen/AArch64/sve-setcc.ll
index 8d7aae877f6af..60ee9b34d1760 100644
--- a/llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ -51,7 +51,10 @@ if.end:
 define void @sve_cmplt_setcc_hslo(<vscale x 8 x i16>* %out, <vscale x 8 x i16> %in, <vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: sve_cmplt_setcc_hslo:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    cmplt p1.h, p0/z, z0.h, #0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    cmplt p2.h, p0/z, z0.h, #0
+; CHECK-NEXT:    and p1.b, p0/z, p0.b, p1.b
+; CHECK-NEXT:    ptest p1, p2.b
 ; CHECK-NEXT:    b.hs .LBB2_2
 ; CHECK-NEXT:  // %bb.1: // %if.then
 ; CHECK-NEXT:    st1h { z0.h }, p0, [x0]


        


More information about the llvm-commits mailing list