[clang] [llvm] [LLVM][IR] Add native vector support to ConstantInt & ConstantFP. (PR #74502)

Paul Walker via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 5 09:42:06 PST 2023


https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/74502

[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.

NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors
    
ConstantExprs are being deprecated and ConstantVector is not space
efficient for larger vector types. This patch introduces an
alternative by allowing ConstantInt to natively support vector
splats via the IR syntax:

  <N x ty> splat(ty <imm>)

More specifically:

 * IR parsing is extended to support the new syntax.
 * ConstantInt gains the interface getSplat().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitcodeReader/Writer is extended to support vector types.
    
Whilst this patch adds the base support, more work is required
before it's production ready. For example, there are likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags is added to allow wider testing and thus help with the
migration:
    
  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: Code generation doesn't work out of the box but the issues look
limited to calls to ConstantInt::getBitWidth() that will need to be
ported.

>From 4c999f2e134ffc0385ec18ecbf1a80a696b7d095 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Nov 2023 14:45:06 +0000
Subject: [PATCH 1/3] [NFC][LLVM][IR] Rename ConstantInt's getType() to
 getIntegerType().

Also adds an assert to ConstantInt::getBitWidth() to ensure it's
only called for integer types. This will have no effect today but
will aid with problem solving when ConstantInt is extended to
support vector types.
---
 clang/lib/CodeGen/CGBuiltin.cpp                   |  7 ++++---
 llvm/include/llvm/IR/Constants.h                  | 12 +++++++++++-
 llvm/lib/Analysis/InstructionSimplify.cpp         |  2 +-
 llvm/lib/IR/ConstantFold.cpp                      |  2 +-
 llvm/lib/IR/Verifier.cpp                          |  4 ++--
 .../Hexagon/HexagonLoopIdiomRecognition.cpp       |  6 +++---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp             | 15 ++++++++-------
 .../InstCombine/InstCombineVectorOps.cpp          |  4 ++--
 llvm/lib/Transforms/Scalar/ConstantHoisting.cpp   |  6 +++---
 llvm/lib/Transforms/Scalar/LoopFlatten.cpp        |  2 +-
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp         |  8 ++++----
 11 files changed, 40 insertions(+), 28 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d9862621061..8dc828abf8aec 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3218,7 +3218,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(PtrValue, Ptr,
@@ -17010,7 +17010,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(Op1, E->getArg(1),
@@ -17248,7 +17248,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
         Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));
 
     if (getTarget().isLittleEndian())
-      Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
+      Index =
+          ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());
 
     return Builder.CreateExtractElement(Unpacked, Index);
   }
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 2f7fc5652c2cd..7bd8bfc477d78 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -136,7 +136,11 @@ class ConstantInt final : public ConstantData {
   inline const APInt &getValue() const { return Val; }
 
   /// getBitWidth - Return the bitwidth of this constant.
-  unsigned getBitWidth() const { return Val.getBitWidth(); }
+  unsigned getBitWidth() const {
+    assert(Value::getType()->isIntegerTy() &&
+           "Returning the bitwidth of a vector constant is not support!");
+    return Val.getBitWidth();
+  }
 
   /// Return the constant as a 64-bit unsigned integer value after it
   /// has been zero extended as appropriate for the type of this constant. Note
@@ -177,6 +181,12 @@ class ConstantInt final : public ConstantData {
     return cast<IntegerType>(Value::getType());
   }
 
+  /// Variant of the getType() method to always return an IntegerType, which
+  /// reduces the amount of casting needed in parts of the compiler.
+  inline IntegerType *getIntegerType() const {
+    return cast<IntegerType>(Value::getType());
+  }
+
   /// This static method returns true if the type Ty is big enough to
   /// represent the value V. This can be used to avoid having the get method
   /// assert when V is larger than Ty can represent. Note that there are two
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..c24bb1bb2cf9f 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6081,7 +6081,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
 
   auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
-  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+  if (!OffsetConstInt || OffsetConstInt->getIntegerType()->getBitWidth() > 64)
     return nullptr;
 
   APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index d499d74f7ba01..c478040234078 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
           }
 
           if (GVAlign > 1) {
-            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned DstWidth = CI2->getIntegerType()->getBitWidth();
             unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
             APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5f466581ea980..5d38d6e5572c3 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2280,7 +2280,7 @@ void Verifier::verifyFunctionMetadata(
       Constant *C = cast<ConstantAsMetadata>(MD->getOperand(0))->getValue();
       Check(isa<ConstantInt>(C),
             "expected a constant integer operand for !kcfi_type", MD);
-      IntegerType *Type = cast<ConstantInt>(C)->getType();
+      IntegerType *Type = cast<ConstantInt>(C)->getIntegerType();
       Check(Type->getBitWidth() == 32,
             "expected a 32-bit integer constant operand for !kcfi_type", MD);
     }
@@ -5672,7 +5672,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           "vector of ints");
 
     auto *Op3 = cast<ConstantInt>(Call.getArgOperand(2));
-    Check(Op3->getType()->getBitWidth() <= 32,
+    Check(Op3->getIntegerType()->getBitWidth() <= 32,
           "third argument of [us][mul|div]_fix[_sat] must fit within 32 bits");
 
     if (ID == Intrinsic::smul_fix || ID == Intrinsic::smul_fix_sat ||
diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
index 51ef72b873a51..fc802c309540f 100644
--- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
@@ -1062,7 +1062,7 @@ void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
   // Promote immediates.
   for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
-      if (CI->getType()->getBitWidth() < DestBW)
+      if (CI->getIntegerType()->getBitWidth() < DestBW)
         In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
   }
 }
@@ -1577,7 +1577,7 @@ Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
 
 static bool hasZeroSignBit(const Value *V) {
   if (const auto *CI = dyn_cast<const ConstantInt>(V))
-    return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0;
+    return (CI->getIntegerType()->getSignBit() & CI->getSExtValue()) == 0;
   const Instruction *I = dyn_cast<const Instruction>(V);
   if (!I)
     return false;
@@ -1688,7 +1688,7 @@ void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
       if (I->getOpcode() != Instruction::Or)
         return nullptr;
       ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
-      if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit())
+      if (!Msb || Msb->getZExtValue() != Msb->getIntegerType()->getSignBit())
         return nullptr;
       if (!hasZeroSignBit(I->getOperand(0)))
         return nullptr;
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index b2665161c090d..2c880316e0a1c 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3763,7 +3763,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *ExecModeC =
         KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedExecModeC = ConstantInt::get(
-        ExecModeC->getType(),
+        ExecModeC->getIntegerType(),
         ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
@@ -3792,7 +3792,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *MayUseNestedParallelismC =
         KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
-        MayUseNestedParallelismC->getType(), NestedParallelism);
+        MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
     setMayUseNestedParallelismOfKernelEnvironment(
         AssumedMayUseNestedParallelismC);
 
@@ -3801,7 +3801,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
               KernelEnvC);
       ConstantInt *AssumedUseGenericStateMachineC =
-          ConstantInt::get(UseGenericStateMachineC->getType(), false);
+          ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
       setUseGenericStateMachineOfKernelEnvironment(
           AssumedUseGenericStateMachineC);
     }
@@ -4280,8 +4280,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // kernel is executed in.
     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
            "Initially non-SPMD kernel has SPMD exec mode!");
-    setExecModeOfKernelEnvironment(ConstantInt::get(
-        ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
+    setExecModeOfKernelEnvironment(
+        ConstantInt::get(ExecModeC->getIntegerType(),
+                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
 
     ++NumOpenMPTargetRegionKernelsSPMD;
 
@@ -4332,7 +4333,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     // If not SPMD mode, indicate we use a custom state machine now.
     setUseGenericStateMachineOfKernelEnvironment(
-        ConstantInt::get(UseStateMachineC->getType(), false));
+        ConstantInt::get(UseStateMachineC->getIntegerType(), false));
 
     // If we don't actually need a state machine we are done here. This can
     // happen if there simply are no parallel regions. In the resulting kernel
@@ -4658,7 +4659,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
             KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
                 AA.KernelEnvC);
         ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
-            MayUseNestedParallelismC->getType(), AA.NestedParallelism);
+            MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
         AA.setMayUseNestedParallelismOfKernelEnvironment(
             NewMayUseNestedParallelismC);
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c8b58c51d4e6e..659ea3b038e14 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -388,7 +388,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
 /// arbitrarily pick 64 bit as our canonical type.  The actual bitwidth doesn't
 /// matter, we just want a consistent type to simplify CSE.
 static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
-  const unsigned IndexBW = IndexC->getType()->getBitWidth();
+  const unsigned IndexBW = IndexC->getIntegerType()->getBitWidth();
   if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
     return nullptr;
   return ConstantInt::get(IndexC->getContext(),
@@ -2639,7 +2639,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
     assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");
 
     // Index is updated to the potentially translated insertion lane.
-    IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex);
+    IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
     return true;
   };
 
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 3e5d979f11cc5..c9f63b17cc15a 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -673,8 +673,8 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) {
   llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS,
                                      const ConstantCandidate &RHS) {
     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
-      return LHS.ConstInt->getType()->getBitWidth() <
-             RHS.ConstInt->getType()->getBitWidth();
+      return LHS.ConstInt->getIntegerType()->getBitWidth() <
+             RHS.ConstInt->getIntegerType()->getBitWidth();
     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
   });
 
@@ -889,7 +889,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
         Type *Ty = ConstInfo.BaseExpr->getType();
         Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP);
       } else {
-        IntegerType *Ty = ConstInfo.BaseInt->getType();
+        IntegerType *Ty = ConstInfo.BaseInt->getIntegerType();
         Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP);
       }
 
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index b1add3c42976f..e2341ea4adb0d 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -343,7 +343,7 @@ static bool verifyTripCount(Value *RHS, Loop *L,
     // If the RHS of the compare is equal to the backedge taken count we need
     // to add one to get the trip count.
     if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
-      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
+      ConstantInt *One = ConstantInt::get(ConstantRHS->getIntegerType(), 1);
       Value *NewRHS = ConstantInt::get(
           ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
       return setLoopComponents(NewRHS, TripCount, Increment,
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index c09cf9c2325c4..1915ccee7e341 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6294,7 +6294,7 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
   }
   case BitMapKind: {
     // Type of the bitmap (e.g. i59).
-    IntegerType *MapTy = BitMap->getType();
+    IntegerType *MapTy = BitMap->getIntegerType();
 
     // Cast Index to the same type as the bitmap.
     // Note: The Index is <= the number of elements in the table, so
@@ -6669,7 +6669,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
   Value *TableIndex;
   ConstantInt *TableIndexOffset;
   if (UseSwitchConditionAsTableIndex) {
-    TableIndexOffset = ConstantInt::get(MaxCaseVal->getType(), 0);
+    TableIndexOffset = ConstantInt::get(MaxCaseVal->getIntegerType(), 0);
     TableIndex = SI->getCondition();
   } else {
     TableIndexOffset = MinCaseVal;
@@ -6753,7 +6753,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
     // Get the TableIndex'th bit of the bitmask.
     // If this bit is 0 (meaning hole) jump to the default destination,
     // else continue with table lookup.
-    IntegerType *MapTy = TableMask->getType();
+    IntegerType *MapTy = TableMask->getIntegerType();
     Value *MaskIndex =
         Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex");
     Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted");
@@ -6976,7 +6976,7 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
   // Replace each case with its trailing zeros number.
   for (auto &Case : SI->cases()) {
     auto *OrigValue = Case.getCaseValue();
-    Case.setValue(ConstantInt::get(OrigValue->getType(),
+    Case.setValue(ConstantInt::get(OrigValue->getIntegerType(),
                                    OrigValue->getValue().countr_zero()));
   }
 

>From 05663ad6641d6675dea9cbbc324f3263df3859b6 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Fri, 1 Dec 2023 14:11:53 +0000
Subject: [PATCH 2/3] [LLVM] Remove ConstantInt's specialisation of getType().

---
 llvm/include/llvm/IR/Constants.h | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 7bd8bfc477d78..b330903a28038 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -137,7 +137,7 @@ class ConstantInt final : public ConstantData {
 
   /// getBitWidth - Return the bitwidth of this constant.
   unsigned getBitWidth() const {
-    assert(Value::getType()->isIntegerTy() &&
+    assert(getType()->isIntegerTy() &&
            "Returning the bitwidth of a vector constant is not support!");
     return Val.getBitWidth();
   }
@@ -174,13 +174,6 @@ class ConstantInt final : public ConstantData {
   /// Determine if this constant's value is same as an unsigned char.
   bool equalsInt(uint64_t V) const { return Val == V; }
 
-  /// getType - Specialize the getType() method to always return an IntegerType,
-  /// which reduces the amount of casting needed in parts of the compiler.
-  ///
-  inline IntegerType *getType() const {
-    return cast<IntegerType>(Value::getType());
-  }
-
   /// Variant of the getType() method to always return an IntegerType, which
   /// reduces the amount of casting needed in parts of the compiler.
   inline IntegerType *getIntegerType() const {

>From d2ff1c7015265fb26d88b3f574d648f519ea531c Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Nov 2023 13:54:33 +0000
Subject: [PATCH 3/3] [LLVM][IR] Add native vector support to ConstantInt &
 ConstantFP.

NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors

ConstantExprs are being deprecated and ConstantVector is not space
efficient for larger vector types. This patch introduces an
alternative by allowing ConstantInt to natively support vector
splats via the IR syntax:

  <N x ty> splat(ty <imm>)

More specifically:

 * IR parsing is extended to support the new syntax.
 * ConstantInt gains the interface getSplat().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitcodeReader/Writer is extended to support vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there are likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags is added to allow wider testing and thus help with the
migration:

  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: Code generation doesn't work out of the box but the issues look
limited to calls to ConstantInt::getBitWidth() that will need to be
ported.
---
 llvm/include/llvm/AsmParser/LLParser.h    |  4 +-
 llvm/include/llvm/AsmParser/LLToken.h     |  1 +
 llvm/include/llvm/IR/Constants.h          | 15 ++++
 llvm/lib/AsmParser/LLLexer.cpp            |  1 +
 llvm/lib/AsmParser/LLParser.cpp           | 56 ++++++++++++--
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 70 ++++++++++-------
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp |  6 +-
 llvm/lib/IR/AsmWriter.cpp                 | 37 +++++++--
 llvm/lib/IR/Constants.cpp                 | 93 ++++++++++++++++++++++-
 llvm/lib/IR/LLVMContextImpl.cpp           |  2 +
 llvm/lib/IR/LLVMContextImpl.h             |  4 +
 llvm/test/Bitcode/constant-splat.ll       | 53 +++++++++++++
 12 files changed, 297 insertions(+), 45 deletions(-)
 create mode 100644 llvm/test/Bitcode/constant-splat.ll

diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index 810f3668d05d4..38f6f08b8f3a1 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -59,7 +59,9 @@ namespace llvm {
       t_Constant,                      // Value in ConstantVal.
       t_InlineAsm,                     // Value in FTy/StrVal/StrVal2/UIntVal.
       t_ConstantStruct,                // Value in ConstantStructElts.
-      t_PackedConstantStruct           // Value in ConstantStructElts.
+      t_PackedConstantStruct,          // Value in ConstantStructElts.
+      t_APSIntSplat,                   // Value in APSIntVal.
+      t_APFloatSplat                   // Value in APFloatVal.
     } Kind = t_LocalID;
 
     LLLexer::LocTy Loc;
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0683291faae72..dd55afee21033 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -335,6 +335,7 @@ enum Kind {
   kw_extractelement,
   kw_insertelement,
   kw_shufflevector,
+  kw_splat,
   kw_extractvalue,
   kw_insertvalue,
   kw_blockaddress,
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index b330903a28038..b76cb1beecf3c 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -81,6 +81,7 @@ class ConstantInt final : public ConstantData {
   APInt Val;
 
   ConstantInt(IntegerType *Ty, const APInt &V);
+  ConstantInt(VectorType *Ty, const APInt &V);
 
   void destroyConstantImpl();
 
@@ -98,6 +99,13 @@ class ConstantInt final : public ConstantData {
   /// value. Otherwise return a ConstantInt for the given value.
   static Constant *get(Type *Ty, uint64_t V, bool IsSigned = false);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantInt::get() should be used.
+  /// Return a ConstantInt with a splat of the given value.
+  static ConstantInt *getSplat(LLVMContext &Context, ElementCount EC,
+                               const APInt &V);
+  static ConstantInt *getSplat(const VectorType *Ty, const APInt &V);
+
   /// Return a ConstantInt with the specified integer value for the specified
   /// type. If the type is wider than 64 bits, the value will be zero-extended
   /// to fit the type, unless IsSigned is true, in which case the value will
@@ -282,6 +290,13 @@ class ConstantFP final : public ConstantData {
   /// value. Otherwise return a ConstantFP for the given value.
   static Constant *get(Type *Ty, const APFloat &V);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantFP::get() should be used.
+  /// Return a ConstantFP with a splat of the given value.
+  static ConstantFP *getSplat(LLVMContext &Context, ElementCount EC,
+                              const APFloat &V);
+  static ConstantFP *getSplat(const VectorType *Ty, const APFloat &V);
+
   static Constant *get(Type *Ty, StringRef Str);
   static ConstantFP *get(LLVMContext &Context, const APFloat &V);
   static Constant *getNaN(Type *Ty, bool Negative = false,
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 09a205c445dbe..eb47284feb218 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -697,6 +697,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(uinc_wrap);
   KEYWORD(udec_wrap);
 
+  KEYWORD(splat);
   KEYWORD(vscale);
   KEYWORD(x);
   KEYWORD(blockaddress);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index d236b6cfa9000..94e1a51aa2e75 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3952,6 +3952,31 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
 
+  case lltok::kw_splat: {
+    Lex.Lex();
+    if (parseToken(lltok::lparen, "expected '(' after vector splat"))
+      return true;
+    Constant *C;
+    if (parseGlobalTypeAndValue(C))
+      return true;
+    if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
+      return true;
+
+    if (auto *CI = dyn_cast<ConstantInt>(C)) {
+      ID.APSIntVal = CI->getValue();
+      ID.Kind = ValID::t_APSIntSplat;
+      return false;
+    }
+
+    if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+      ID.APFloatVal = CFP->getValue();
+      ID.Kind = ValID::t_APFloatSplat;
+      return false;
+    }
+
+    return tokError("invalid splat operand");
+  }
+
   case lltok::kw_getelementptr:
   case lltok::kw_shufflevector:
   case lltok::kw_insertelement:
@@ -5716,9 +5741,23 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
     ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getPrimitiveSizeInBits());
     V = ConstantInt::get(Context, ID.APSIntVal);
     return false;
+  case ValID::t_APSIntSplat:
+    if (!Ty->isVectorTy() || !Ty->getScalarType()->isIntegerTy())
+      return error(ID.Loc, "expected an integer vector result");
+    if (ID.APSIntVal.getBitWidth() !=
+        cast<IntegerType>(Ty->getScalarType())->getBitWidth())
+      return error(ID.Loc, "operand type must match result element type");
+    V = ConstantInt::getSplat(cast<VectorType>(Ty), ID.APSIntVal);
+    return false;
   case ValID::t_APFloat:
-    if (!Ty->isFloatingPointTy() ||
-        !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
+  case ValID::t_APFloatSplat: {
+    if ((ID.Kind == ValID::t_APFloat && !Ty->isFloatingPointTy()) ||
+        (ID.Kind == ValID::t_APFloatSplat && !Ty->isVectorTy()))
+      return error(ID.Loc, "floating point constant invalid for type");
+
+    Type *ScalarTy = Ty->getScalarType();
+    if (!ScalarTy->isFloatingPointTy() ||
+        !ConstantFP::isValueValidForType(ScalarTy, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
     // The lexer has no type info, so builds all half, bfloat, float, and double
@@ -5727,13 +5766,13 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (ScalarTy->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isBFloatTy())
+      else if (ScalarTy->isBFloatTy())
         ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isFloatTy())
+      else if (ScalarTy->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       if (IsSNAN) {
@@ -5745,13 +5784,18 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
                                          ID.APFloatVal.isNegative(), &Payload);
       }
     }
-    V = ConstantFP::get(Context, ID.APFloatVal);
+
+    if (auto *VTy = dyn_cast<VectorType>(Ty))
+      V = ConstantFP::getSplat(VTy, ID.APFloatVal);
+    else
+      V = ConstantFP::get(Context, ID.APFloatVal);
 
     if (V->getType() != Ty)
       return error(ID.Loc, "floating point constant does not have type '" +
                                getTypeString(Ty) + "'");
 
     return false;
+  }
   case ValID::t_Null:
     if (!Ty->isPointerTy())
       return error(ID.Loc, "null must be a pointer type");
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index e4c3770946b3a..b661d36fb6854 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3022,50 +3022,62 @@ Error BitcodeReader::parseConstants() {
       V = Constant::getNullValue(CurTy);
       break;
     case bitc::CST_CODE_INTEGER:   // INTEGER: [intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid integer const record");
-      V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy)) {
+        auto *ScalarTy = cast<IntegerType>(VTy->getScalarType());
+        unsigned BitWidth = ScalarTy->getBitWidth();
+        APInt VInt(BitWidth, decodeSignRotatedValue(Record[0]));
+        V = ConstantInt::getSplat(VTy, VInt);
+      } else
+        V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
       break;
     case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid wide integer const record");
 
-      APInt VInt =
-          readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
-      V = ConstantInt::get(Context, VInt);
-
+      auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
+      APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantInt::getSplat(VTy, VInt);
+      else
+        V = ConstantInt::get(Context, VInt);
       break;
     }
     case bitc::CST_CODE_FLOAT: {    // FLOAT: [fpval]
       if (Record.empty())
         return error("Invalid float const record");
-      if (CurTy->isHalfTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
-                                             APInt(16, (uint16_t)Record[0])));
-      else if (CurTy->isBFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
-                                             APInt(16, (uint32_t)Record[0])));
-      else if (CurTy->isFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
-                                             APInt(32, (uint32_t)Record[0])));
-      else if (CurTy->isDoubleTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
-                                             APInt(64, Record[0])));
-      else if (CurTy->isX86_FP80Ty()) {
+
+      APFloat Val(APFloat::Bogus());
+      auto *ScalarTy = CurTy->getScalarType();
+      if (ScalarTy->isHalfTy())
+        Val = APFloat(APFloat::IEEEhalf(), APInt(16, (uint16_t)Record[0]));
+      else if (ScalarTy->isBFloatTy())
+        Val = APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0]));
+      else if (ScalarTy->isFloatTy())
+        Val = APFloat(APFloat::IEEEsingle(), APInt(32, (uint32_t)Record[0]));
+      else if (ScalarTy->isDoubleTy())
+        Val = APFloat(APFloat::IEEEdouble(), APInt(64, Record[0]));
+      else if (ScalarTy->isX86_FP80Ty()) {
         // Bits are not stored the same way as a normal i80 APInt, compensate.
         uint64_t Rearrange[2];
         Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
         Rearrange[1] = Record[0] >> 48;
-        V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
-                                             APInt(80, Rearrange)));
-      } else if (CurTy->isFP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
-                                             APInt(128, Record)));
-      else if (CurTy->isPPC_FP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
-                                             APInt(128, Record)));
-      else
+        Val = APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange));
+      } else if (ScalarTy->isFP128Ty())
+        Val = APFloat(APFloat::IEEEquad(), APInt(128, Record));
+      else if (ScalarTy->isPPC_FP128Ty())
+        Val = APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record));
+      else {
         V = UndefValue::get(CurTy);
+        break;
+      }
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantFP::getSplat(VTy, Val);
+      else
+        V = ConstantFP::get(Context, Val);
       break;
     }
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8239775d04865..0f5b9ff9ebd72 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2577,18 +2577,18 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (isa<UndefValue>(C)) {
       Code = bitc::CST_CODE_UNDEF;
     } else if (const ConstantInt *IV = dyn_cast<ConstantInt>(C)) {
-      if (IV->getBitWidth() <= 64) {
+      if (IV->getValue().getBitWidth() <= 64) {
         uint64_t V = IV->getSExtValue();
         emitSignedInt64(Record, V);
         Code = bitc::CST_CODE_INTEGER;
         AbbrevToUse = CONSTANTS_INTEGER_ABBREV;
-      } else {                             // Wide integers, > 64 bits in size.
+      } else { // Wide integers, > 64 bits in size.
         emitWideAPInt(Record, IV->getValue());
         Code = bitc::CST_CODE_WIDE_INTEGER;
       }
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
-      Type *Ty = CFP->getType();
+      Type *Ty = CFP->getType()->getScalarType();
       if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
           Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index fabc79adbd33d..e37da7c460f26 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1394,16 +1394,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
 static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   AsmWriterContext &WriterCtx) {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
-    if (CI->getType()->isIntegerTy(1)) {
-      Out << (CI->getZExtValue() ? "true" : "false");
-      return;
+    if (CI->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
+      Out << " ";
     }
-    Out << CI->getValue();
+
+    if (CI->getType()->getScalarType()->isIntegerTy(1))
+      Out << (CI->getZExtValue() ? "true" : "false");
+    else
+      Out << CI->getValue();
+
+    if (CI->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
     const APFloat &APF = CFP->getValueAPF();
+
+    if (CFP->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
+      Out << " ";
+    }
+
     if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
         &APF.getSemantics() == &APFloat::IEEEdouble()) {
       // We would like to output the FP constant value in exponential notation,
@@ -1429,6 +1445,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         // Reparse stringized version!
         if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
           Out << StrVal;
+
+          if (CFP->getType()->isVectorTy())
+            Out << ")";
+
           return;
         }
       }
@@ -1454,6 +1474,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         }
       }
       Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);
+
+      if (CFP->getType()->isVectorTy())
+        Out << ")";
+
       return;
     }
 
@@ -1468,7 +1492,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
                                   /*Upper=*/true);
-      return;
     } else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
       Out << 'L';
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
@@ -1491,6 +1514,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
+
+    if (CFP->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index bc55d5b485271..64995fb46689b 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -35,6 +35,20 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+// Temporary, default-off options to help migrate how splats are represented.
+static cl::opt<bool> UseConstantIntForFixedLengthSplat(
+    "use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantInt's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantFPForFixedLengthSplat(
+    "use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantFP's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantIntForScalableSplat(
+    "use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantInt's native scalable vector splat support."));
+static cl::opt<bool> UseConstantFPForScalableSplat(
+    "use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantFP's native scalable vector splat support."));
+
 //===----------------------------------------------------------------------===//
 //                              Constant Class
 //===----------------------------------------------------------------------===//
@@ -830,6 +844,13 @@ ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
   assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
 }
 
+ConstantInt::ConstantInt(VectorType *Ty, const APInt &V) // Splat: Val per lane.
+    : ConstantData(Ty, ConstantIntVal), Val(V) {
+  assert(V.getBitWidth() ==
+             cast<IntegerType>(Ty->getElementType())->getBitWidth() &&
+         "Invalid constant for type");
+}
+
 ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
   LLVMContextImpl *pImpl = Context.pImpl;
   if (!pImpl->TheTrueVal)
@@ -915,6 +936,32 @@ ConstantInt *ConstantInt::get(IntegerType* Ty, StringRef Str, uint8_t radix) {
   return get(Ty->getContext(), APInt(Ty->getBitWidth(), Str, radix));
 }
 
+// Get the uniqued ConstantInt vector with EC lanes, each lane equal to V.
+ConstantInt *ConstantInt::getSplat(LLVMContext &Context, ElementCount EC,
+                                   const APInt &V) {
+  IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+  VectorType *VTy = VectorType::get(ITy, EC);
+
+  // Look up the uniquing slot; operator[] inserts an empty one if absent.
+  std::unique_ptr<ConstantInt> &Slot =
+      Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
+  if (!Slot)
+    Slot.reset(new ConstantInt(VTy, V));
+
+  assert(Slot->getType() == VTy);
+  return Slot.get();
+}
+
+// Get a ConstantInt splat of type Ty; V's width must match Ty's elements.
+ConstantInt *ConstantInt::getSplat(const VectorType *Ty, const APInt &V) {
+  assert(Ty->getElementType()->isIntegerTy() &&
+         "Expected integer vector type!");
+  assert(cast<IntegerType>(Ty->getElementType())->getBitWidth() ==
+             V.getBitWidth() &&
+         "Expected value of same bitwidth as vector element type!");
+  return getSplat(Ty->getContext(), Ty->getElementCount(), V);
+}
+
 /// Remove the constant from the constant table.
 void ConstantInt::destroyConstantImpl() {
   llvm_unreachable("You can't ConstantInt->destroyConstantImpl()!");
@@ -1036,7 +1083,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
 
 ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
     : ConstantData(Ty, ConstantFPVal), Val(V) {
-  assert(&V.getSemantics() == &Ty->getFltSemantics() &&
+  assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
          "FP type Mismatch");
 }
 
@@ -1044,6 +1091,30 @@ bool ConstantFP::isExactlyValue(const APFloat &V) const {
   return Val.bitwiseIsEqual(V);
 }
 
+// Get the uniqued ConstantFP vector with EC lanes, each lane equal to V.
+ConstantFP *ConstantFP::getSplat(LLVMContext &Context, ElementCount EC,
+                                 const APFloat &V) {
+  Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
+  VectorType *VTy = VectorType::get(EltTy, EC);
+
+  // Look up the uniquing slot; operator[] inserts an empty one if absent.
+  std::unique_ptr<ConstantFP> &Slot =
+      Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
+  if (!Slot)
+    Slot.reset(new ConstantFP(VTy, V));
+
+  assert(Slot->getType() == VTy);
+  return Slot.get();
+}
+
+// Get a ConstantFP splat of type Ty; Ty's element type must match V.
+ConstantFP *ConstantFP::getSplat(const VectorType *Ty, const APFloat &V) {
+  assert(Ty->getElementType() ==
+             Type::getFloatingPointTy(Ty->getContext(), V.getSemantics()) &&
+         "Expected value of same semantics as vector element type!");
+  return getSplat(Ty->getContext(), Ty->getElementCount(), V);
+}
+
 /// Remove the constant from the constant table.
 void ConstantFP::destroyConstantImpl() {
   llvm_unreachable("You can't ConstantFP->destroyConstantImpl()!");
@@ -1384,6 +1455,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
 
 Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
   if (!EC.isScalable()) {
+    // Maintain special handling of zero.
+    if (!V->isNullValue()) {
+      if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
+        return ConstantInt::getSplat(V->getContext(), EC,
+                                     cast<ConstantInt>(V)->getValue());
+      if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
+        return ConstantFP::getSplat(V->getContext(), EC,
+                                    cast<ConstantFP>(V)->getValue());
+    }
+
     // If this splat is compatible with ConstantDataVector, use it instead of
     // ConstantVector.
     if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1475,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
     return get(Elts);
   }
 
+  // Maintain special handling of zero.
+  if (!V->isNullValue()) {
+    if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
+      return ConstantInt::getSplat(V->getContext(), EC,
+                                   cast<ConstantInt>(V)->getValue());
+    if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
+      return ConstantFP::getSplat(V->getContext(), EC,
+                                  cast<ConstantFP>(V)->getValue());
+  }
+
   Type *VTy = VectorType::get(V->getType(), EC);
 
   if (V->isNullValue())
diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp
index 15c90a4fe7b2e..a0bf9cae7926b 100644
--- a/llvm/lib/IR/LLVMContextImpl.cpp
+++ b/llvm/lib/IR/LLVMContextImpl.cpp
@@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
   IntZeroConstants.clear();
   IntOneConstants.clear();
   IntConstants.clear();
+  IntSplatConstants.clear();
   FPConstants.clear();
+  FPSplatConstants.clear();
   CDSConstants.clear();
 
   // Destroy attribute node lists.
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 6a20291344989..2ee1080a1ffa2 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1488,8 +1488,12 @@ class LLVMContextImpl {
   DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
   DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
   DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
+  DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
+      IntSplatConstants;
 
   DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
+  DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
+      FPSplatConstants;
 
   FoldingSet<AttributeImpl> AttrsSet;
   FoldingSet<AttributeListImpl> AttrsLists;
diff --git a/llvm/test/Bitcode/constant-splat.ll b/llvm/test/Bitcode/constant-splat.ll
new file mode 100644
index 0000000000000..558d83e6c21ee
--- /dev/null
+++ b/llvm/test/Bitcode/constant-splat.ll
@@ -0,0 +1,53 @@
+; RUN: llvm-as < %s | llvm-dis | llvm-as | llvm-dis | FileCheck %s
+
+; CHECK: @constant.splat.i1 = constant <1 x i1> splat (i1 true)
+@constant.splat.i1 = constant <1 x i1> splat (i1 true)
+
+; CHECK: @constant.splat.i32 = constant <5 x i32> splat (i32 7)
+@constant.splat.i32 = constant <5 x i32> splat (i32 7)
+
+; CHECK: @constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)
+@constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)
+
+; CHECK: @constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)
+@constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)
+
+; CHECK: @constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)
+@constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)
+
+; CHECK: @constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)
+@constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)
+
+; CHECK: @constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+@constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+
+; CHECK: @constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)
+@constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)
+
+; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+
+; CHECK: @constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+@constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+
+define void @add_fixed_length_vector_splat_i32(<4 x i32> %a) {
+; CHECK: %add = add <4 x i32> %a, splat (i32 137)
+  %add = add <4 x i32> %a, splat (i32 137)
+  ret void
+}
+
+define <4 x i32> @ret_fixed_length_vector_splat_i32() {
+; CHECK: ret <4 x i32> splat (i32 56)
+  ret <4 x i32> splat (i32 56)
+}
+
+define void @add_scalable_vector_splat_double(<vscale x 2 x double> %a) {
+; CHECK: %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  ret void
+}
+
+define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
+; CHECK: ret <vscale x 4 x i32> splat (i32 78)
+  ret <vscale x 4 x i32> splat (i32 78)
+}



More information about the cfe-commits mailing list