[clang-tools-extra] [llvm] [clang] [AArch64][SVE2] Lower OR to SLI/SRI (PR #77555)

Usman Nadeem via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 11:24:42 PST 2024


https://github.com/UsmanNadeem updated https://github.com/llvm/llvm-project/pull/77555

>From 7eeacff38b6d95fb2eb0fe13cad660801e7982fd Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Tue, 9 Jan 2024 20:20:10 -0800
Subject: [PATCH 1/2] [AArch64][SVE2] Lower OR to SLI/SRI

The code builds on the existing NEON code, and the tests are adapted
from the NEON tests, minus the tests for illegal types.

Change-Id: I11325949700fb7433f948bbe3e82dbc71696aecc
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 152 ++++++----
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |   4 +-
 llvm/lib/Target/AArch64/AArch64Subtarget.h    |   1 +
 llvm/test/CodeGen/AArch64/sve2-sli-sri.ll     | 263 ++++++++++++++++++
 4 files changed, 357 insertions(+), 63 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2-sli-sri.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2ab..269dde004bea78 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1358,6 +1358,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
       if (!Subtarget->isLittleEndian())
         setOperationAction(ISD::BITCAST, VT, Expand);
+
+      if (Subtarget->hasSVE2orSME())
+        // For SLI/SRI.
+        setOperationAction(ISD::OR, VT, Custom);
     }
 
     // Illegal unpacked integer vector types.
@@ -5411,7 +5415,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
 
   case Intrinsic::aarch64_neon_vsri:
-  case Intrinsic::aarch64_neon_vsli: {
+  case Intrinsic::aarch64_neon_vsli:
+  case Intrinsic::aarch64_sve_sri:
+  case Intrinsic::aarch64_sve_sli: {
     EVT Ty = Op.getValueType();
 
     if (!Ty.isVector())
@@ -5419,7 +5425,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
 
     assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
 
-    bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri;
+    bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
+                        IntNo == Intrinsic::aarch64_sve_sri;
     unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
     return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
                        Op.getOperand(3));
@@ -12544,6 +12551,53 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
   return true;
 }
 
+static bool isAllInactivePredicate(SDValue N) {
+  // Look through cast.
+  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    N = N.getOperand(0);
+
+  return ISD::isConstantSplatVectorAllZeros(N.getNode());
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
+  unsigned NumElts = N.getValueType().getVectorMinNumElements();
+
+  // Look through cast.
+  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
+    N = N.getOperand(0);
+    // When reinterpreting from a type with fewer elements the "new" elements
+    // are not active, so bail if they're likely to be used.
+    if (N.getValueType().getVectorMinNumElements() < NumElts)
+      return false;
+  }
+
+  if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
+    return true;
+
+  // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
+  // or smaller than the implicit element type represented by N.
+  // NOTE: A larger element count implies a smaller element type.
+  if (N.getOpcode() == AArch64ISD::PTRUE &&
+      N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
+    return N.getValueType().getVectorMinNumElements() >= NumElts;
+
+  // If we're compiling for a specific vector-length, we can check if the
+  // pattern's VL equals that of the scalable vector at runtime.
+  if (N.getOpcode() == AArch64ISD::PTRUE) {
+    const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+    if (MaxSVESize && MinSVESize == MaxSVESize) {
+      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+      unsigned PatNumElts =
+          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+      return PatNumElts == (NumElts * VScale);
+    }
+  }
+
+  return false;
+}
+
 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
 // BUILD_VECTORs with constant element C1, C2 is a constant, and:
@@ -12569,32 +12623,52 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   // Is one of the operands an AND or a BICi? The AND may have been optimised to
   // a BICi in order to use an immediate instead of a register.
   // Is the other operand an shl or lshr? This will have been turned into:
-  // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift.
+  // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
+  // or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
   if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
-      (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR)) {
+      (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
+       SecondOpc == AArch64ISD::SHL_PRED ||
+       SecondOpc == AArch64ISD::SRL_PRED)) {
     And = FirstOp;
     Shift = SecondOp;
 
   } else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
-             (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR)) {
+             (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
+              FirstOpc == AArch64ISD::SHL_PRED ||
+              FirstOpc == AArch64ISD::SRL_PRED)) {
     And = SecondOp;
     Shift = FirstOp;
   } else
     return SDValue();
 
   bool IsAnd = And.getOpcode() == ISD::AND;
-  bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR;
-
-  // Is the shift amount constant?
-  ConstantSDNode *C2node = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
-  if (!C2node)
+  bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
+                      Shift.getOpcode() == AArch64ISD::SRL_PRED;
+  bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
+                        Shift.getOpcode() == AArch64ISD::SRL_PRED;
+
+  // Is the shift amount constant and are all lanes active?
+  uint64_t C2;
+  if (ShiftHasPredOp) {
+    if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
+      return SDValue();
+    APInt C;
+    if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
+      return SDValue();
+    C2 = C.getZExtValue();
+  } else if (ConstantSDNode *C2node =
+                 dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
+    C2 = C2node->getZExtValue();
+  else
     return SDValue();
 
   uint64_t C1;
   if (IsAnd) {
     // Is the and mask vector all constant?
-    if (!isAllConstantBuildVector(And.getOperand(1), C1))
+    APInt C;
+    if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C))
       return SDValue();
+    C1 = C.getZExtValue();
   } else {
     // Reconstruct the corresponding AND immediate from the two BICi immediates.
     ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
@@ -12606,7 +12680,6 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
   // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
   // how much one can shift elements of a particular size?
-  uint64_t C2 = C2node->getZExtValue();
   unsigned ElemSizeInBits = VT.getScalarSizeInBits();
   if (C2 > ElemSizeInBits)
     return SDValue();
@@ -12618,10 +12691,12 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
     return SDValue();
 
   SDValue X = And.getOperand(0);
-  SDValue Y = Shift.getOperand(0);
+  SDValue Y = (ShiftHasPredOp) ? Shift.getOperand(1) : Shift.getOperand(0);
+  SDValue Imm = (ShiftHasPredOp) ? DAG.getTargetConstant(C2, DL, MVT::i32)
+                                 : Shift.getOperand(1);
 
   unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
-  SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Shift.getOperand(1));
+  SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);
 
   LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
   LLVM_DEBUG(N->dump(&DAG));
@@ -12643,6 +12718,8 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
     return Res;
 
   EVT VT = Op.getValueType();
+  if (VT.isScalableVector())
+    return Op;
 
   SDValue LHS = Op.getOperand(0);
   BuildVectorSDNode *BVN =
@@ -17434,53 +17511,6 @@ static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
   return false;
 }
 
-static bool isAllInactivePredicate(SDValue N) {
-  // Look through cast.
-  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
-    N = N.getOperand(0);
-
-  return ISD::isConstantSplatVectorAllZeros(N.getNode());
-}
-
-static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
-  unsigned NumElts = N.getValueType().getVectorMinNumElements();
-
-  // Look through cast.
-  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
-    N = N.getOperand(0);
-    // When reinterpreting from a type with fewer elements the "new" elements
-    // are not active, so bail if they're likely to be used.
-    if (N.getValueType().getVectorMinNumElements() < NumElts)
-      return false;
-  }
-
-  if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
-    return true;
-
-  // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
-  // or smaller than the implicit element type represented by N.
-  // NOTE: A larger element count implies a smaller element type.
-  if (N.getOpcode() == AArch64ISD::PTRUE &&
-      N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
-    return N.getValueType().getVectorMinNumElements() >= NumElts;
-
-  // If we're compiling for a specific vector-length, we can check if the
-  // pattern's VL equals that of the scalable vector at runtime.
-  if (N.getOpcode() == AArch64ISD::PTRUE) {
-    const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
-    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
-    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
-    if (MaxSVESize && MinSVESize == MaxSVESize) {
-      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
-      unsigned PatNumElts =
-          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
-      return PatNumElts == (NumElts * VScale);
-    }
-  }
-
-  return false;
-}
-
 static SDValue performReinterpretCastCombine(SDNode *N) {
   SDValue LeafOp = SDValue(N, 0);
   SDValue Op = N->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 344a153890631e..da9021f6e0feb5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3574,8 +3574,8 @@ let Predicates = [HasSVE2orSME] in {
   defm PMULLT_ZZZ   : sve2_pmul_long<0b1, "pmullt", int_aarch64_sve_pmullt_pair>;
 
   // SVE2 bitwise shift and insert
-  defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", int_aarch64_sve_sri>;
-  defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", int_aarch64_sve_sli>;
+  defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", AArch64vsri>;
+  defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", AArch64vsli>;
 
   // SVE2 bitwise shift right and accumulate
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra",  AArch64ssra>;
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index b17e215e200dea..a131cf8a6f5402 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -394,6 +394,7 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
   void mirFileLoaded(MachineFunction &MF) const override;
 
   bool hasSVEorSME() const { return hasSVE() || hasSME(); }
+  bool hasSVE2orSME() const { return hasSVE2() || hasSME(); }
 
   // Return the known range for the bit length of SVE data registers. A value
   // of 0 means nothing is known about that particular limit beyong what's
diff --git a/llvm/test/CodeGen/AArch64/sve2-sli-sri.ll b/llvm/test/CodeGen/AArch64/sve2-sli-sri.ll
new file mode 100644
index 00000000000000..80999fb1f4864b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-sli-sri.ll
@@ -0,0 +1,263 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s -o - | FileCheck --check-prefixes=CHECK,SVE %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s -o - | FileCheck --check-prefixes=CHECK,SVE2 %s
+
+define <vscale x 16 x i8> @testLeftGood16x8(<vscale x 16 x i8> %src1, <vscale x 16 x i8> %src2) {
+; SVE-LABEL: testLeftGood16x8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.b, z0.b, #0x7
+; SVE-NEXT:    lsl z1.b, z1.b, #3
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testLeftGood16x8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sli z0.b, z1.b, #3
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 16 x i8> %src1, splat(i8 7)
+  %vshl_n = shl <vscale x 16 x i8> %src2, splat(i8 3)
+  %result = or <vscale x 16 x i8> %and.i, %vshl_n
+  ret <vscale x 16 x i8> %result
+}
+
+define <vscale x 16 x i8> @testLeftBad16x8(<vscale x 16 x i8> %src1, <vscale x 16 x i8> %src2) {
+; CHECK-LABEL: testLeftBad16x8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, #-91 // =0xffffffffffffffa5
+; CHECK-NEXT:    lsl z1.b, z1.b, #1
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 16 x i8> %src1, splat(i8 165)
+  %vshl_n = shl <vscale x 16 x i8> %src2, splat(i8 1)
+  %result = or <vscale x 16 x i8> %and.i, %vshl_n
+  ret <vscale x 16 x i8> %result
+}
+
+define <vscale x 16 x i8> @testRightGood16x8(<vscale x 16 x i8> %src1, <vscale x 16 x i8> %src2) {
+; SVE-LABEL: testRightGood16x8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.b, z0.b, #0xe0
+; SVE-NEXT:    lsr z1.b, z1.b, #3
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testRightGood16x8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sri z0.b, z1.b, #3
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 16 x i8> %src1, splat(i8 224)
+  %vshl_n = lshr <vscale x 16 x i8> %src2, splat(i8 3)
+  %result = or <vscale x 16 x i8> %and.i, %vshl_n
+  ret <vscale x 16 x i8> %result
+}
+
+define <vscale x 16 x i8> @testRightBad16x8(<vscale x 16 x i8> %src1, <vscale x 16 x i8> %src2) {
+; CHECK-LABEL: testRightBad16x8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, #-91 // =0xffffffffffffffa5
+; CHECK-NEXT:    lsr z1.b, z1.b, #1
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 16 x i8> %src1, splat(i8 165)
+  %vshl_n = lshr <vscale x 16 x i8> %src2, splat(i8 1)
+  %result = or <vscale x 16 x i8> %and.i, %vshl_n
+  ret <vscale x 16 x i8> %result
+}
+
+define <vscale x 8 x i16> @testLeftGood8x16(<vscale x 8 x i16> %src1, <vscale x 8 x i16> %src2) {
+; SVE-LABEL: testLeftGood8x16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.h, z0.h, #0x3fff
+; SVE-NEXT:    lsl z1.h, z1.h, #14
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testLeftGood8x16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sli z0.h, z1.h, #14
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 8 x i16> %src1, splat(i16 16383)
+  %vshl_n = shl <vscale x 8 x i16> %src2, splat(i16 14)
+  %result = or <vscale x 8 x i16> %and.i, %vshl_n
+  ret <vscale x 8 x i16> %result
+}
+
+define <vscale x 8 x i16> @testLeftBad8x16(<vscale x 8 x i16> %src1, <vscale x 8 x i16> %src2) {
+; CHECK-LABEL: testLeftBad8x16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16500 // =0x4074
+; CHECK-NEXT:    lsl z1.h, z1.h, #14
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 8 x i16> %src1, splat(i16 16500)
+  %vshl_n = shl <vscale x 8 x i16> %src2, splat(i16 14)
+  %result = or <vscale x 8 x i16> %and.i, %vshl_n
+  ret <vscale x 8 x i16> %result
+}
+
+define <vscale x 8 x i16> @testRightGood8x16(<vscale x 8 x i16> %src1, <vscale x 8 x i16> %src2) {
+; SVE-LABEL: testRightGood8x16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.h, z0.h, #0xfffc
+; SVE-NEXT:    lsr z1.h, z1.h, #14
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testRightGood8x16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sri z0.h, z1.h, #14
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 8 x i16> %src1, splat(i16 65532)
+  %vshl_n = lshr <vscale x 8 x i16> %src2, splat(i16 14)
+  %result = or <vscale x 8 x i16> %and.i, %vshl_n
+  ret <vscale x 8 x i16> %result
+}
+
+define <vscale x 8 x i16> @testRightBad8x16(<vscale x 8 x i16> %src1, <vscale x 8 x i16> %src2) {
+; CHECK-LABEL: testRightBad8x16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16500 // =0x4074
+; CHECK-NEXT:    lsr z1.h, z1.h, #14
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 8 x i16> %src1, splat(i16 16500)
+  %vshl_n = lshr <vscale x 8 x i16> %src2, splat(i16 14)
+  %result = or <vscale x 8 x i16> %and.i, %vshl_n
+  ret <vscale x 8 x i16> %result
+}
+
+define <vscale x 4 x i32> @testLeftGood4x32(<vscale x 4 x i32> %src1, <vscale x 4 x i32> %src2) {
+; SVE-LABEL: testLeftGood4x32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.s, z0.s, #0x3fffff
+; SVE-NEXT:    lsl z1.s, z1.s, #22
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testLeftGood4x32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sli z0.s, z1.s, #22
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 4 x i32> %src1, splat(i32 4194303)
+  %vshl_n = shl <vscale x 4 x i32> %src2, splat(i32 22)
+  %result = or <vscale x 4 x i32> %and.i, %vshl_n
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 4 x i32> @testLeftBad4x32(<vscale x 4 x i32> %src1, <vscale x 4 x i32> %src2) {
+; CHECK-LABEL: testLeftBad4x32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0x3ffffc
+; CHECK-NEXT:    lsl z1.s, z1.s, #22
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 4 x i32> %src1, splat(i32 4194300)
+  %vshl_n = shl <vscale x 4 x i32> %src2, splat(i32 22)
+  %result = or <vscale x 4 x i32> %and.i, %vshl_n
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 4 x i32> @testRightGood4x32(<vscale x 4 x i32> %src1, <vscale x 4 x i32> %src2) {
+; SVE-LABEL: testRightGood4x32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.s, z0.s, #0xfffffc00
+; SVE-NEXT:    lsr z1.s, z1.s, #22
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testRightGood4x32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sri z0.s, z1.s, #22
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 4 x i32> %src1, splat(i32 4294966272)
+  %vshl_n = lshr <vscale x 4 x i32> %src2, splat(i32 22)
+  %result = or <vscale x 4 x i32> %and.i, %vshl_n
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 4 x i32> @testRightBad4x32(<vscale x 4 x i32> %src1, <vscale x 4 x i32> %src2) {
+; CHECK-LABEL: testRightBad4x32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0x3ffffc
+; CHECK-NEXT:    lsr z1.s, z1.s, #22
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 4 x i32> %src1, splat(i32 4194300)
+  %vshl_n = lshr <vscale x 4 x i32> %src2, splat(i32 22)
+  %result = or <vscale x 4 x i32> %and.i, %vshl_n
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 2 x i64> @testLeftGood2x64(<vscale x 2 x i64> %src1, <vscale x 2 x i64> %src2) {
+; SVE-LABEL: testLeftGood2x64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.d, z0.d, #0xffffffffffff
+; SVE-NEXT:    lsl z1.d, z1.d, #48
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testLeftGood2x64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sli z0.d, z1.d, #48
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 2 x i64> %src1, splat(i64 281474976710655)
+  %vshl_n = shl <vscale x 2 x i64> %src2, splat(i64 48)
+  %result = or <vscale x 2 x i64> %and.i, %vshl_n
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @testLeftBad2x64(<vscale x 2 x i64> %src1, <vscale x 2 x i64> %src2) {
+; CHECK-LABEL: testLeftBad2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, #10 // =0xa
+; CHECK-NEXT:    lsl z1.d, z1.d, #48
+; CHECK-NEXT:    movk x8, #1, lsl #48
+; CHECK-NEXT:    mov z2.d, x8
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 2 x i64> %src1, splat(i64 281474976710666)
+  %vshl_n = shl <vscale x 2 x i64> %src2, splat(i64 48)
+  %result = or <vscale x 2 x i64> %and.i, %vshl_n
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @testRightGood2x64(<vscale x 2 x i64> %src1, <vscale x 2 x i64> %src2) {
+; SVE-LABEL: testRightGood2x64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    and z0.d, z0.d, #0xffffffffffff0000
+; SVE-NEXT:    lsr z1.d, z1.d, #48
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: testRightGood2x64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    sri z0.d, z1.d, #48
+; SVE2-NEXT:    ret
+  %and.i = and <vscale x 2 x i64> %src1, splat(i64 18446744073709486080)
+  %vshl_n = lshr <vscale x 2 x i64> %src2, splat(i64 48)
+  %result = or <vscale x 2 x i64> %and.i, %vshl_n
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @testRightBad2x64(<vscale x 2 x i64> %src1, <vscale x 2 x i64> %src2) {
+; CHECK-LABEL: testRightBad2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, #10 // =0xa
+; CHECK-NEXT:    lsr z1.d, z1.d, #48
+; CHECK-NEXT:    movk x8, #1, lsl #48
+; CHECK-NEXT:    mov z2.d, x8
+; CHECK-NEXT:    and z0.d, z0.d, z2.d
+; CHECK-NEXT:    orr z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %and.i = and <vscale x 2 x i64> %src1, splat(i64 281474976710666)
+  %vshl_n = lshr <vscale x 2 x i64> %src2, splat(i64 48)
+  %result = or <vscale x 2 x i64> %and.i, %vshl_n
+  ret <vscale x 2 x i64> %result
+}

>From 241411204d5ac80046432078fa6675243b169b10 Mon Sep 17 00:00:00 2001
From: "Nadeem, Usman" <mnadeem at quicinc.com>
Date: Wed, 10 Jan 2024 12:04:05 -0800
Subject: [PATCH 2/2] fixup! [AArch64][SVE2] Lower OR to SLI/SRI

---
 .../lib/Target/AArch64/AArch64ISelLowering.cpp | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 269dde004bea78..d3b6c86d5c3395 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12662,38 +12662,36 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   else
     return SDValue();
 
-  uint64_t C1;
+  APInt C1AsAPInt;
+  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
   if (IsAnd) {
     // Is the and mask vector all constant?
-    APInt C;
-    if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C))
+    if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
       return SDValue();
-    C1 = C.getZExtValue();
   } else {
     // Reconstruct the corresponding AND immediate from the two BICi immediates.
     ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
     ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
     assert(C1nodeImm && C1nodeShift);
-    C1 = ~(C1nodeImm->getZExtValue() << C1nodeShift->getZExtValue());
+    C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
+    C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
   }
 
   // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
   // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
   // how much one can shift elements of a particular size?
-  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
   if (C2 > ElemSizeInBits)
     return SDValue();
 
-  APInt C1AsAPInt(ElemSizeInBits, C1);
   APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
                                   : APInt::getLowBitsSet(ElemSizeInBits, C2);
   if (C1AsAPInt != RequiredC1)
     return SDValue();
 
   SDValue X = And.getOperand(0);
-  SDValue Y = (ShiftHasPredOp) ? Shift.getOperand(1) : Shift.getOperand(0);
-  SDValue Imm = (ShiftHasPredOp) ? DAG.getTargetConstant(C2, DL, MVT::i32)
-                                 : Shift.getOperand(1);
+  SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
+  SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
+                               : Shift.getOperand(1);
 
   unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
   SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);



More information about the llvm-commits mailing list