[llvm] 67d4c99 - Add support for (expressing) vscale.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 22 02:10:22 PST 2020


Author: Sander de Smalen
Date: 2020-01-22T10:09:27Z
New Revision: 67d4c9924c1fbfdbfcfa90bf729945eca0a92f86

URL: https://github.com/llvm/llvm-project/commit/67d4c9924c1fbfdbfcfa90bf729945eca0a92f86
DIFF: https://github.com/llvm/llvm-project/commit/67d4c9924c1fbfdbfcfa90bf729945eca0a92f86.diff

LOG: Add support for (expressing) vscale.

In LLVM IR, vscale can be represented with an intrinsic. For some targets,
this is equivalent to the constexpr:

  getelementptr <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1

This can be used to propagate the value in CodeGenPrepare.

In ISel we add a node that can be legalized to one or more
instructions to materialize the runtime vector length.

This patch also adds SVE CodeGen support for VSCALE, which maps this
node to RDVL instructions (for scaled multiples of 16 bytes) or CNT[HSD]
instructions (scaled multiples of 2, 4, or 8 bytes, respectively).

Reviewers: rengolin, cameron.mcinally, hfinkel, sebpop, SjoerdMeijer, efriedma, lattner

Reviewed by: efriedma

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68203

Added: 
    llvm/test/CodeGen/AArch64/sve-vscale.ll

Modified: 
    llvm/docs/LangRef.rst
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/include/llvm/IR/Intrinsics.td
    llvm/include/llvm/IR/PatternMatch.h
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/Analysis/ConstantFolding.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e3c2a0afcb0e..50348f7bb57a 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -17889,6 +17889,34 @@ information on the *based on* terminology see
 mask argument does not match the pointer size of the target, the mask is
 zero-extended or truncated accordingly.
 
+.. _int_vscale:
+
+'``llvm.vscale``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+::
+
+      declare i32 @llvm.vscale.i32()
+      declare i64 @llvm.vscale.i64()
+
+Overview:
+"""""""""
+
+The ``llvm.vscale`` intrinsic returns the value for ``vscale`` in scalable
+vectors such as ``<vscale x 16 x i8>``.
+
+Semantics:
+""""""""""
+
+``vscale`` is a positive value that is constant throughout program
+execution, but is unknown at compile time.
+If the result value does not fit in the result type, then the result is
+a :ref:`poison value <poisonvalues>`.
+
+
 Stack Map Intrinsics
 --------------------
 

diff  --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 06140fae8790..2e447cf96159 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -921,6 +921,11 @@ namespace ISD {
     /// known nonzero constant. The only operand here is the chain.
     GET_DYNAMIC_AREA_OFFSET,
 
+    /// VSCALE(IMM) - Returns the runtime scaling factor used to calculate the
+    /// number of elements within a scalable vector. IMM is a constant integer
+    /// multiplier that is applied to the runtime value.
+    VSCALE,
+
     /// Generic reduction nodes. These nodes represent horizontal vector
     /// reduction operations, producing a scalar result.
     /// The STRICT variants perform reductions in sequential order. The first

diff  --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1c1a365d2376..e848e33b3a3c 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -914,6 +914,13 @@ class SelectionDAG {
     return getNode(ISD::UNDEF, SDLoc(), VT);
   }
 
+  /// Return a node that represents the runtime scaling 'MulImm * RuntimeVL'.
+  SDValue getVScale(const SDLoc &DL, EVT VT, APInt MulImm) {
+    assert(MulImm.getMinSignedBits() <= VT.getSizeInBits() &&
+           "Immediate does not fit VT");
+    return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
+  }
+
   /// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc.
   SDValue getGLOBAL_OFFSET_TABLE(EVT VT) {
     return getNode(ISD::GLOBAL_OFFSET_TABLE, SDLoc(), VT);

diff  --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index aa9f5adc4aba..d549eb27b4ba 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1342,6 +1342,11 @@ def int_preserve_struct_access_index : Intrinsic<[llvm_anyptr_ty],
                                                  [IntrNoMem, ImmArg<1>,
                                                   ImmArg<2>]>;
 
+//===---------- Intrinsics to query properties of scalable vectors --------===//
+def int_vscale : Intrinsic<[llvm_anyint_ty], [], [IntrNoMem]>;
+
+//===----------------------------------------------------------------------===//
+
 //===----------------------------------------------------------------------===//
 // Target-specific intrinsics
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 6621fc9f819c..bc22614c1373 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -32,6 +32,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -2002,6 +2003,42 @@ inline ExtractValue_match<Ind, Val_t> m_ExtractValue(const Val_t &V) {
   return ExtractValue_match<Ind, Val_t>(V);
 }
 
+/// Matches patterns for `vscale`. This can either be a call to `llvm.vscale` or
+/// the constant expression
+///  `ptrtoint(gep <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1)`
+/// under the right conditions determined by DataLayout.
+struct VScaleVal_match {
+private:
+  template <typename Base, typename Offset>
+  inline BinaryOp_match<Base, Offset, Instruction::GetElementPtr>
+  m_OffsetGep(const Base &B, const Offset &O) {
+    return BinaryOp_match<Base, Offset, Instruction::GetElementPtr>(B, O);
+  }
+
+public:
+  const DataLayout &DL;
+  VScaleVal_match(const DataLayout &DL) : DL(DL) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    if (m_Intrinsic<Intrinsic::vscale>().match(V))
+      return true;
+
+    if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) {
+      Type *PtrTy = cast<Operator>(V)->getOperand(0)->getType();
+      Type *DerefTy = PtrTy->getPointerElementType();
+      if (DerefTy->isVectorTy() && DerefTy->getVectorIsScalable() &&
+          DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8)
+        return true;
+    }
+
+    return false;
+  }
+};
+
+inline VScaleVal_match m_VScale(const DataLayout &DL) {
+  return VScaleVal_match(DL);
+}
+
 } // end namespace PatternMatch
 } // end namespace llvm
 

diff  --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 6ee711716463..1f3ea152f239 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -316,6 +316,7 @@ def vt         : SDNode<"ISD::VALUETYPE" , SDTOther   , [], "VTSDNode">;
 def bb         : SDNode<"ISD::BasicBlock", SDTOther   , [], "BasicBlockSDNode">;
 def cond       : SDNode<"ISD::CONDCODE"  , SDTOther   , [], "CondCodeSDNode">;
 def undef      : SDNode<"ISD::UNDEF"     , SDTUNDEF   , []>;
+def vscale     : SDNode<"ISD::VSCALE"    , SDTIntUnaryOp, []>;
 def globaladdr : SDNode<"ISD::GlobalAddress",         SDTPtrLeaf, [],
                         "GlobalAddressSDNode">;
 def tglobaladdr : SDNode<"ISD::TargetGlobalAddress",  SDTPtrLeaf, [],

diff  --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index b32924e6497a..280a13ff1153 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -828,7 +828,8 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
   Type *SrcElemTy = GEP->getSourceElementType();
   Type *ResElemTy = GEP->getResultElementType();
   Type *ResTy = GEP->getType();
-  if (!SrcElemTy->isSized())
+  if (!SrcElemTy->isSized() ||
+      (SrcElemTy->isVectorTy() && SrcElemTy->getVectorIsScalable()))
     return nullptr;
 
   if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy,

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index f05afd058746..01adfc197041 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -2010,6 +2010,22 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, bool &ModifiedDT) {
       return despeculateCountZeros(II, TLI, DL, ModifiedDT);
     case Intrinsic::dbg_value:
       return fixupDbgValue(II);
+    case Intrinsic::vscale: {
+      // If datalayout has no special restrictions on vector data layout,
+      // replace `llvm.vscale` by an equivalent constant expression
+      // to benefit from cheap constant propagation.
+      Type *ScalableVectorTy =
+          VectorType::get(Type::getInt8Ty(II->getContext()), 1, true);
+      if (DL->getTypeAllocSize(ScalableVectorTy).getKnownMinSize() == 8) {
+        auto Null = Constant::getNullValue(ScalableVectorTy->getPointerTo());
+        auto One = ConstantInt::getSigned(II->getType(), 1);
+        auto *CGep =
+            ConstantExpr::getGetElementPtr(ScalableVectorTy, Null, One);
+        II->replaceAllUsesWith(ConstantExpr::getPtrToInt(CGep, II->getType()));
+        II->eraseFromParent();
+        return true;
+      }
+    }
     }
 
     if (TLI) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 3458bf084169..89edddb42be1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -91,6 +91,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::TRUNCATE:    Res = PromoteIntRes_TRUNCATE(N); break;
   case ISD::UNDEF:       Res = PromoteIntRes_UNDEF(N); break;
   case ISD::VAARG:       Res = PromoteIntRes_VAARG(N); break;
+  case ISD::VSCALE:      Res = PromoteIntRes_VSCALE(N); break;
 
   case ISD::EXTRACT_SUBVECTOR:
                          Res = PromoteIntRes_EXTRACT_SUBVECTOR(N); break;
@@ -1179,6 +1180,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_UNDEF(SDNode *N) {
                                                N->getValueType(0)));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VSCALE(SDNode *N) {
+  EVT VT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+
+  APInt MulImm = cast<ConstantSDNode>(N->getOperand(0))->getAPIntValue();
+  return DAG.getVScale(SDLoc(N), VT, MulImm.sextOrSelf(VT.getSizeInBits()));
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_VAARG(SDNode *N) {
   SDValue Chain = N->getOperand(0); // Get the chain.
   SDValue Ptr = N->getOperand(1); // Get the pointer.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index faae14444d51..10351ca5fcbf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -326,6 +326,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_UNDEF(SDNode *N);
   SDValue PromoteIntRes_VAARG(SDNode *N);
+  SDValue PromoteIntRes_VSCALE(SDNode *N);
   SDValue PromoteIntRes_XMULO(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_ADDSUBSAT(SDNode *N);
   SDValue PromoteIntRes_MULFIX(SDNode *N);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 6c793afefe0d..94e5011576cd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5185,11 +5185,20 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     if (N2C && N2C->isNullValue())
       return N1;
     break;
+  case ISD::MUL:
+    assert(VT.isInteger() && "This operator does not apply to FP types!");
+    assert(N1.getValueType() == N2.getValueType() &&
+           N1.getValueType() == VT && "Binary operator types must match!");
+    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
+      APInt MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getAPIntValue();
+      APInt N2CImm = N2C->getAPIntValue();
+      return getVScale(DL, VT, MulImm * N2CImm);
+    }
+    break;
   case ISD::UDIV:
   case ISD::UREM:
   case ISD::MULHU:
   case ISD::MULHS:
-  case ISD::MUL:
   case ISD::SDIV:
   case ISD::SREM:
   case ISD::SMIN:
@@ -5222,6 +5231,12 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
            "Invalid FCOPYSIGN!");
     break;
   case ISD::SHL:
+    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
+      APInt MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getAPIntValue();
+      APInt ShiftImm = N2C->getAPIntValue();
+      return getVScale(DL, VT, MulImm << ShiftImm);
+    }
+    LLVM_FALLTHROUGH;
   case ISD::SRA:
   case ISD::SRL:
     if (SDValue V = simplifyShift(N1, N2))

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index e43f79c0c030..59931b1ea5b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1482,6 +1482,9 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
                              TLI.getPointerTy(DAG.getDataLayout(), AS));
     }
 
+    if (match(C, m_VScale(DAG.getDataLayout())))
+      return DAG.getVScale(getCurSDLoc(), VT, APInt(VT.getSizeInBits(), 1));
+
     if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C))
       return DAG.getConstantFP(*CFP, getCurSDLoc(), VT);
 
@@ -5772,6 +5775,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     // By default, turn this into a target intrinsic node.
     visitTargetIntrinsic(I, Intrinsic);
     return;
+  case Intrinsic::vscale: {
+    match(&I, m_VScale(DAG.getDataLayout()));
+    EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType());
+    setValue(&I,
+             DAG.getVScale(getCurSDLoc(), VT, APInt(VT.getSizeInBits(), 1)));
+    return;
+  }
   case Intrinsic::vastart:  visitVAStart(I); return;
   case Intrinsic::vaend:    visitVAEnd(I); return;
   case Intrinsic::vacopy:   visitVACopy(I); return;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 6fd71393bf38..6a46fb41339d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -170,6 +170,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::CopyToReg:                  return "CopyToReg";
   case ISD::CopyFromReg:                return "CopyFromReg";
   case ISD::UNDEF:                      return "undef";
+  case ISD::VSCALE:                     return "vscale";
   case ISD::MERGE_VALUES:               return "merge_values";
   case ISD::INLINEASM:                  return "inlineasm";
   case ISD::INLINEASM_BR:               return "inlineasm_br";

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index a51aa85a931c..9d6cc3a24d8a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -62,6 +62,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
                                     unsigned ConstraintID,
                                     std::vector<SDValue> &OutOps) override;
 
+  template <signed Low, signed High, signed Scale>
+  bool SelectRDVLImm(SDValue N, SDValue &Imm);
+
   bool tryMLAV64LaneV128(SDNode *N);
   bool tryMULLV64LaneV128(unsigned IntNo, SDNode *N);
   bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
@@ -679,6 +682,23 @@ static SDValue narrowIfNeeded(SelectionDAG *CurDAG, SDValue N) {
   return SDValue(Node, 0);
 }
 
+// Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
+template<signed Low, signed High, signed Scale>
+bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
+  if (!isa<ConstantSDNode>(N))
+    return false;
+
+  int64_t MulImm = cast<ConstantSDNode>(N)->getSExtValue();
+  if ((MulImm % std::abs(Scale)) == 0) {
+    int64_t RDVLImm = MulImm / Scale;
+    if ((RDVLImm >= Low) && (RDVLImm <= High)) {
+      Imm = CurDAG->getTargetConstant(RDVLImm, SDLoc(N), MVT::i32);
+      return true;
+    }
+  }
+
+  return false;
+}
 
 /// SelectArithExtendedRegister - Select a "extended register" operand.  This
 /// operand folds in an extend followed by an optional left shift.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0e871c229204..7cbc441bec7a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -836,6 +836,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       }
     }
 
+    if (Subtarget->hasSVE())
+      setOperationAction(ISD::VSCALE, MVT::i32, Custom);
+
     setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
   }
 
@@ -3254,6 +3257,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerATOMIC_LOAD_AND(Op, DAG);
   case ISD::DYNAMIC_STACKALLOC:
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
+  case ISD::VSCALE:
+    return LowerVSCALE(Op, DAG);
   }
 }
 
@@ -8641,6 +8646,17 @@ AArch64TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
   return DAG.getMergeValues(Ops, dl);
 }
 
+SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  assert(VT != MVT::i64 && "Expected illegal VSCALE node");
+
+  SDLoc DL(Op);
+  APInt MulImm = cast<ConstantSDNode>(Op.getOperand(0))->getAPIntValue();
+  return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm.sextOrSelf(64)),
+                            DL, VT);
+}
+
 /// getTgtMemIntrinsic - Represent NEON load and store intrinsics as
 /// MemIntrinsicNodes.  The associated MachineMemOperands record the alignment
 /// specified in the intrinsic calls.

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 5aaeebef3088..09a4d324e535 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -748,6 +748,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerATOMIC_LOAD_SUB(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const;

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 90d954d56539..bab89e7c5654 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -46,6 +46,17 @@ def AArch64ld1_gather_uxtw_scaled    : SDNode<"AArch64ISD::GLD1_UXTW_SCALED",
 def AArch64ld1_gather_sxtw_scaled    : SDNode<"AArch64ISD::GLD1_SXTW_SCALED",   SDT_AArch64_GLD1,     [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
 def AArch64ld1_gather_imm            : SDNode<"AArch64ISD::GLD1_IMM",           SDT_AArch64_GLD1_IMM, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
 
+// SVE CNT/INC/RDVL
+def sve_rdvl_imm : ComplexPattern<i32, 1, "SelectRDVLImm<-32, 31, 16>">;
+def sve_cnth_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 8>">;
+def sve_cntw_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 4>">;
+def sve_cntd_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 2>">;
+
+// SVE DEC
+def sve_cnth_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -8>">;
+def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
+def sve_cntd_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -2>">;
+
 def AArch64ld1s_gather               : SDNode<"AArch64ISD::GLD1S",              SDT_AArch64_GLD1,     [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
 def AArch64ld1s_gather_scaled        : SDNode<"AArch64ISD::GLD1S_SCALED",       SDT_AArch64_GLD1,     [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
 def AArch64ld1s_gather_uxtw          : SDNode<"AArch64ISD::GLD1S_UXTW",         SDT_AArch64_GLD1,     [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
@@ -1105,6 +1116,23 @@ let Predicates = [HasSVE] in {
   def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i8),  (SXTB_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>;
   def : Pat<(sext_inreg (nxv8i16 ZPR:$Zs), nxv8i8),  (SXTB_ZPmZ_H (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>;
 
+  // General case that we ideally never want to match.
+  def : Pat<(vscale GPR64:$scale), (MADDXrrr (UBFMXri (RDVLI_XI 1), 4, 63), $scale, XZR)>;
+
+  let AddedComplexity = 5 in {
+    def : Pat<(vscale (i64 1)), (UBFMXri (RDVLI_XI 1), 4, 63)>;
+    def : Pat<(vscale (i64 -1)), (SBFMXri (RDVLI_XI -1), 4, 63)>;
+
+    def : Pat<(vscale (sve_rdvl_imm i32:$imm)), (RDVLI_XI $imm)>;
+    def : Pat<(vscale (sve_cnth_imm i32:$imm)), (CNTH_XPiI 31, $imm)>;
+    def : Pat<(vscale (sve_cntw_imm i32:$imm)), (CNTW_XPiI 31, $imm)>;
+    def : Pat<(vscale (sve_cntd_imm i32:$imm)), (CNTD_XPiI 31, $imm)>;
+
+    def : Pat<(vscale (sve_cnth_imm_neg i32:$imm)), (SUBXrs XZR, (CNTH_XPiI 31, $imm), 0)>;
+    def : Pat<(vscale (sve_cntw_imm_neg i32:$imm)), (SUBXrs XZR, (CNTW_XPiI 31, $imm), 0)>;
+    def : Pat<(vscale (sve_cntd_imm_neg i32:$imm)), (SUBXrs XZR, (CNTD_XPiI 31, $imm), 0)>;
+  }
+
   def : Pat<(nxv16i8 (bitconvert (nxv8i16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
   def : Pat<(nxv16i8 (bitconvert (nxv4i32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
   def : Pat<(nxv16i8 (bitconvert (nxv2i64 ZPR:$src))), (nxv16i8 ZPR:$src)>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-vscale.ll b/llvm/test/CodeGen/AArch64/sve-vscale.ll
new file mode 100644
index 000000000000..d156b1997e59
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vscale.ll
@@ -0,0 +1,200 @@
+; RUN: llc -mtriple aarch64 -mattr=+sve -asm-verbose=0 < %s | FileCheck %s
+; RUN: opt -codegenprepare -S < %s | llc -mtriple aarch64 -mattr=+sve -asm-verbose=0 | FileCheck %s
+
+;
+; RDVL
+;
+
+; CHECK-LABEL: rdvl_i8:
+; CHECK:       rdvl x0, #1
+; CHECK-NEXT:  ret
+define i8 @rdvl_i8() nounwind {
+  %vscale = call i8 @llvm.vscale.i8()
+  %1 = mul nsw i8 %vscale, 16
+  ret i8 %1
+}
+
+; CHECK-LABEL: rdvl_i16:
+; CHECK:       rdvl x0, #1
+; CHECK-NEXT:  ret
+define i16 @rdvl_i16() nounwind {
+  %vscale = call i16 @llvm.vscale.i16()
+  %1 = mul nsw i16 %vscale, 16
+  ret i16 %1
+}
+
+; CHECK-LABEL: rdvl_i32:
+; CHECK:       rdvl x0, #1
+; CHECK-NEXT:  ret
+define i32 @rdvl_i32() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 16
+  ret i32 %1
+}
+
+; CHECK-LABEL: rdvl_i64:
+; CHECK:       rdvl x0, #1
+; CHECK-NEXT:  ret
+define i64 @rdvl_i64() nounwind {
+  %vscale = call i64 @llvm.vscale.i64()
+  %1 = mul nsw i64 %vscale, 16
+  ret i64 %1
+}
+
+; CHECK-LABEL: rdvl_const:
+; CHECK:       rdvl x0, #1
+; CHECK-NEXT:  ret
+define i32 @rdvl_const() nounwind {
+  ret i32 mul nsw (i32 ptrtoint (<vscale x 1 x i8>* getelementptr (<vscale x 1 x i8>, <vscale x 1 x i8>* null, i64 1) to i32), i32 16)
+}
+
+define i32 @vscale_1() nounwind {
+; CHECK-LABEL: vscale_1:
+; CHECK:       rdvl [[TMP:x[0-9]+]], #1
+; CHECK-NEXT:  lsr  x0, [[TMP]], #4
+; CHECK-NEXT:  ret
+  %vscale = call i32 @llvm.vscale.i32()
+  ret i32 %vscale
+}
+
+define i32 @vscale_neg1() nounwind {
+; CHECK-LABEL: vscale_neg1:
+; CHECK:       rdvl [[TMP:x[0-9]+]], #-1
+; CHECK-NEXT:  asr  x0, [[TMP]], #4
+; CHECK-NEXT:  ret
+  %vscale = call i32 @llvm.vscale.i32()
+  %neg = mul nsw i32 -1, %vscale
+  ret i32 %neg
+}
+
+; CHECK-LABEL: rdvl_3:
+; CHECK:       rdvl [[VL_B:x[0-9]+]], #1
+; CHECK-NEXT:  lsr  [[VL_Q:x[0-9]+]], [[VL_B]], #4
+; CHECK-NEXT:  mov  w[[MUL:[0-9]+]], #3
+; CHECK-NEXT:  mul  x0, [[VL_Q]], x[[MUL]]
+; CHECK-NEXT:  ret
+define i32 @rdvl_3() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 3
+  ret i32 %1
+}
+
+
+; CHECK-LABEL: rdvl_min:
+; CHECK:       rdvl x0, #-32
+; CHECK-NEXT:  ret
+define i32 @rdvl_min() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, -512
+  ret i32 %1
+}
+
+; CHECK-LABEL: rdvl_max:
+; CHECK:       rdvl x0, #31
+; CHECK-NEXT:  ret
+define i32 @rdvl_max() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 496
+  ret i32 %1
+}
+
+;
+; CNTH
+;
+
+; CHECK-LABEL: cnth:
+; CHECK:       cnth x0{{$}}
+; CHECK-NEXT:  ret
+define i32 @cnth() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = shl nsw i32 %vscale, 3
+  ret i32 %1
+}
+
+; CHECK-LABEL: cnth_max:
+; CHECK:       cnth x0, all, mul #15
+; CHECK-NEXT:  ret
+define i32 @cnth_max() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 120
+  ret i32 %1
+}
+
+; CHECK-LABEL: cnth_neg:
+; CHECK:       cnth [[CNT:x[0-9]+]]
+; CHECK:       neg x0, [[CNT]]
+; CHECK-NEXT:  ret
+define i32 @cnth_neg() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, -8
+  ret i32 %1
+}
+
+;
+; CNTW
+;
+
+; CHECK-LABEL: cntw:
+; CHECK:       cntw x0{{$}}
+; CHECK-NEXT:  ret
+define i32 @cntw() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = shl nsw i32 %vscale, 2
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntw_max:
+; CHECK:       cntw x0, all, mul #15
+; CHECK-NEXT:  ret
+define i32 @cntw_max() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 60
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntw_neg:
+; CHECK:       cntw [[CNT:x[0-9]+]]
+; CHECK:       neg x0, [[CNT]]
+; CHECK-NEXT:  ret
+define i32 @cntw_neg() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, -4
+  ret i32 %1
+}
+
+;
+; CNTD
+;
+
+; CHECK-LABEL: cntd:
+; CHECK:       cntd x0{{$}}
+; CHECK-NEXT:  ret
+define i32 @cntd() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = shl nsw i32 %vscale, 1
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntd_max:
+; CHECK:       cntd x0, all, mul #15
+; CHECK-NEXT:  ret
+define i32 @cntd_max() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, 30
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntd_neg:
+; CHECK:       cntd [[CNT:x[0-9]+]]
+; CHECK:       neg x0, [[CNT]]
+; CHECK-NEXT:  ret
+define i32 @cntd_neg() nounwind {
+  %vscale = call i32 @llvm.vscale.i32()
+  %1 = mul nsw i32 %vscale, -2
+  ret i32 %1
+}
+
+declare i8 @llvm.vscale.i8()
+declare i16 @llvm.vscale.i16()
+declare i32 @llvm.vscale.i32()
+declare i64 @llvm.vscale.i64()


        


More information about the llvm-commits mailing list