[llvm] [LLVM][IR] Add native vector support to ConstantInt & ConstantFP. (PR #74502)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 09:07:57 PST 2024


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/74502

>From 1a494eca2018dfc94e75ded00ab81a3efd34b15f Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Nov 2023 13:54:33 +0000
Subject: [PATCH 1/2] [LLVM][IR] Add native vector support to ConstantInt &
 ConstantFP.

NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors

However, ConstantExprs are being deprecated and ConstantVector is
not space efficient for larger vector types. By extending ConstantInt
we can represent vector splats by only storing the underlying scalar
value.

More specifically:

 * ConstantInt gains an ElementCount variant of get().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitcodeReader/Writer support is extended to allow vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there are likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly,
the default behaviour of ConstantInt::get() remains unchanged, but a
set of flags is added to allow wider testing and thus help with the
migration:

  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: For reasons similar to those above, code generation doesn't work
out of the box.
---
 llvm/include/llvm/IR/Constants.h          | 12 +++-
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 55 +++++++--------
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp |  2 +-
 llvm/lib/IR/AsmWriter.cpp                 | 27 ++++++--
 llvm/lib/IR/Constants.cpp                 | 82 ++++++++++++++++++++++-
 llvm/lib/IR/LLVMContextImpl.cpp           |  2 +
 llvm/lib/IR/LLVMContextImpl.h             |  4 ++
 llvm/test/Bitcode/constant-splat.ll       | 61 +++++++++++++++++
 8 files changed, 208 insertions(+), 37 deletions(-)
 create mode 100644 llvm/test/Bitcode/constant-splat.ll

diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index b5dcc7fbc1d929..39eec1b738fabb 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -81,7 +81,7 @@ class ConstantInt final : public ConstantData {
 
   APInt Val;
 
-  ConstantInt(IntegerType *Ty, const APInt &V);
+  ConstantInt(Type *Ty, const APInt &V);
 
   void destroyConstantImpl();
 
@@ -123,6 +123,12 @@ class ConstantInt final : public ConstantData {
   /// type is the integer type that corresponds to the bit width of the value.
   static ConstantInt *get(LLVMContext &Context, const APInt &V);
 
+  /// Return a ConstantInt with the specified value and an implied Type. The
+  /// type is the vector type whose integer element type corresponds to the bit
+  /// width of the value.
+  static ConstantInt *get(LLVMContext &Context, ElementCount EC,
+                          const APInt &V);
+
   /// Return a ConstantInt constructed from the string strStart with the given
   /// radix.
   static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix);
@@ -136,7 +142,7 @@ class ConstantInt final : public ConstantData {
   /// Return the constant's value.
   inline const APInt &getValue() const { return Val; }
 
-  /// getBitWidth - Return the bitwidth of this constant.
+  /// getBitWidth - Return the scalar bitwidth of this constant.
   unsigned getBitWidth() const { return Val.getBitWidth(); }
 
   /// Return the constant as a 64-bit unsigned integer value after it
@@ -281,6 +287,8 @@ class ConstantFP final : public ConstantData {
 
   static Constant *get(Type *Ty, StringRef Str);
   static ConstantFP *get(LLVMContext &Context, const APFloat &V);
+  static ConstantFP *get(LLVMContext &Context, ElementCount EC,
+                         const APFloat &V);
   static Constant *getNaN(Type *Ty, bool Negative = false,
                           uint64_t Payload = 0);
   static Constant *getQNaN(Type *Ty, bool Negative = false,
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 515a1d0caa0415..832907a3f53f5f 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() {
       V = Constant::getNullValue(CurTy);
       break;
     case bitc::CST_CODE_INTEGER:   // INTEGER: [intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid integer const record");
       V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
       break;
     case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid wide integer const record");
 
-      APInt VInt =
-          readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
-      V = ConstantInt::get(Context, VInt);
-
+      auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
+      APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
+      V = ConstantInt::get(CurTy, VInt);
       break;
     }
     case bitc::CST_CODE_FLOAT: {    // FLOAT: [fpval]
       if (Record.empty())
         return error("Invalid float const record");
-      if (CurTy->isHalfTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
-                                             APInt(16, (uint16_t)Record[0])));
-      else if (CurTy->isBFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
-                                             APInt(16, (uint32_t)Record[0])));
-      else if (CurTy->isFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
-                                             APInt(32, (uint32_t)Record[0])));
-      else if (CurTy->isDoubleTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
-                                             APInt(64, Record[0])));
-      else if (CurTy->isX86_FP80Ty()) {
+
+      auto *ScalarTy = CurTy->getScalarType();
+      if (ScalarTy->isHalfTy())
+        V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
+                                           APInt(16, (uint16_t)Record[0])));
+      else if (ScalarTy->isBFloatTy())
+        V = ConstantFP::get(
+            CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
+      else if (ScalarTy->isFloatTy())
+        V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
+                                           APInt(32, (uint32_t)Record[0])));
+      else if (ScalarTy->isDoubleTy())
+        V = ConstantFP::get(
+            CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
+      else if (ScalarTy->isX86_FP80Ty()) {
         // Bits are not stored the same way as a normal i80 APInt, compensate.
         uint64_t Rearrange[2];
         Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
         Rearrange[1] = Record[0] >> 48;
-        V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
-                                             APInt(80, Rearrange)));
-      } else if (CurTy->isFP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
-                                             APInt(128, Record)));
-      else if (CurTy->isPPC_FP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
-                                             APInt(128, Record)));
+        V = ConstantFP::get(
+            CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
+      } else if (ScalarTy->isFP128Ty())
+        V = ConstantFP::get(CurTy,
+                            APFloat(APFloat::IEEEquad(), APInt(128, Record)));
+      else if (ScalarTy->isPPC_FP128Ty())
+        V = ConstantFP::get(
+            CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
       else
         V = UndefValue::get(CurTy);
       break;
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 13be0b0c3307fb..656f2a6ce870f5 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
       }
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
-      Type *Ty = CFP->getType();
+      Type *Ty = CFP->getType()->getScalarType();
       if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
           Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 0ae720e8b7ce8c..1fcda6c384d96d 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1502,16 +1502,35 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) {
 static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   AsmWriterContext &WriterCtx) {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
-    if (CI->getType()->isIntegerTy(1)) {
-      Out << (CI->getZExtValue() ? "true" : "false");
-      return;
+    if (CI->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
+      Out << " ";
     }
-    Out << CI->getValue();
+
+    if (CI->getType()->getScalarType()->isIntegerTy(1))
+      Out << (CI->getZExtValue() ? "true" : "false");
+    else
+      Out << CI->getValue();
+
+    if (CI->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
+    if (CFP->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
+      Out << " ";
+    }
+
     WriteAPFloatInternal(Out, CFP->getValueAPF());
+
+    if (CFP->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a38b912164b130..b04d7955afe670 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -35,6 +35,20 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+// A set of temporary options to help migrate how splats are represented.
+static cl::opt<bool> UseConstantIntForFixedLengthSplat(
+    "use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantInt's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantFPForFixedLengthSplat(
+    "use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantFP's native fixed-length vector splat support."));
+static cl::opt<bool> UseConstantIntForScalableSplat(
+    "use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantInt's native scalable vector splat support."));
+static cl::opt<bool> UseConstantFPForScalableSplat(
+    "use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
+    cl::desc("Use ConstantFP's native scalable vector splat support."));
+
 //===----------------------------------------------------------------------===//
 //                              Constant Class
 //===----------------------------------------------------------------------===//
@@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
 //                                ConstantInt
 //===----------------------------------------------------------------------===//
 
-ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
+ConstantInt::ConstantInt(Type *Ty, const APInt &V)
     : ConstantData(Ty, ConstantIntVal), Val(V) {
-  assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
+  assert(V.getBitWidth() ==
+             cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
+         "Invalid constant for type");
 }
 
 ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
@@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
   return Slot.get();
 }
 
+// Get a ConstantInt vector with each lane set to the same APInt.
+ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
+                              const APInt &V) {
+  // Get an existing value or the insertion position.
+  std::unique_ptr<ConstantInt> &Slot =
+      Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
+  if (!Slot) {
+    IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+    VectorType *VTy = VectorType::get(ITy, EC);
+    Slot.reset(new ConstantInt(VTy, V));
+  }
+
+#ifndef NDEBUG
+  IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+  VectorType *VTy = VectorType::get(ITy, EC);
+  assert(Slot->getType() == VTy);
+#endif
+  return Slot.get();
+}
+
 Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
   Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
 
@@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
   return Slot.get();
 }
 
+// Get a ConstantFP vector with each lane set to the same APFloat.
+ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
+                            const APFloat &V) {
+  // Get an existing value or the insertion position.
+  std::unique_ptr<ConstantFP> &Slot =
+      Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
+  if (!Slot) {
+    Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
+    VectorType *VTy = VectorType::get(EltTy, EC);
+    Slot.reset(new ConstantFP(VTy, V));
+  }
+
+#ifndef NDEBUG
+  Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
+  VectorType *VTy = VectorType::get(EltTy, EC);
+  assert(Slot->getType() == VTy);
+#endif
+  return Slot.get();
+}
+
 Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
   const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
   Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
@@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
 
 ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
     : ConstantData(Ty, ConstantFPVal), Val(V) {
-  assert(&V.getSemantics() == &Ty->getFltSemantics() &&
+  assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
          "FP type Mismatch");
 }
 
@@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
 
 Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
   if (!EC.isScalable()) {
+    // Maintain special handling of zero.
+    if (!V->isNullValue()) {
+      if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
+        return ConstantInt::get(V->getContext(), EC,
+                                cast<ConstantInt>(V)->getValue());
+      if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
+        return ConstantFP::get(V->getContext(), EC,
+                               cast<ConstantFP>(V)->getValue());
+    }
+
     // If this splat is compatible with ConstantDataVector, use it instead of
     // ConstantVector.
     if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
@@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
     return get(Elts);
   }
 
+  // Maintain special handling of zero.
+  if (!V->isNullValue()) {
+    if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
+      return ConstantInt::get(V->getContext(), EC,
+                              cast<ConstantInt>(V)->getValue());
+    if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
+      return ConstantFP::get(V->getContext(), EC,
+                             cast<ConstantFP>(V)->getValue());
+  }
+
   Type *VTy = VectorType::get(V->getType(), EC);
 
   if (V->isNullValue())
diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp
index 15c90a4fe7b2ec..a0bf9cae7926bb 100644
--- a/llvm/lib/IR/LLVMContextImpl.cpp
+++ b/llvm/lib/IR/LLVMContextImpl.cpp
@@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
   IntZeroConstants.clear();
   IntOneConstants.clear();
   IntConstants.clear();
+  IntSplatConstants.clear();
   FPConstants.clear();
+  FPSplatConstants.clear();
   CDSConstants.clear();
 
   // Destroy attribute node lists.
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 6a20291344989d..2ee1080a1ffa29 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1488,8 +1488,12 @@ class LLVMContextImpl {
   DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
   DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
   DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
+  DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
+      IntSplatConstants;
 
   DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
+  DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
+      FPSplatConstants;
 
   FoldingSet<AttributeImpl> AttrsSet;
   FoldingSet<AttributeListImpl> AttrsLists;
diff --git a/llvm/test/Bitcode/constant-splat.ll b/llvm/test/Bitcode/constant-splat.ll
new file mode 100644
index 00000000000000..d4921607d15b54
--- /dev/null
+++ b/llvm/test/Bitcode/constant-splat.ll
@@ -0,0 +1,61 @@
+; RUN: llvm-as -use-constant-int-for-fixed-length-splat \
+; RUN:         -use-constant-fp-for-fixed-length-splat \
+; RUN:         -use-constant-int-for-scalable-splat \
+; RUN:         -use-constant-fp-for-scalable-splat \
+; RUN:   < %s | llvm-dis -use-constant-int-for-fixed-length-splat \
+; RUN:                   -use-constant-fp-for-fixed-length-splat \
+; RUN:                   -use-constant-int-for-scalable-splat \
+; RUN:                   -use-constant-fp-for-scalable-splat \
+; RUN:   | FileCheck %s
+
+; CHECK: @constant.splat.i1 = constant <1 x i1> splat (i1 true)
+@constant.splat.i1 = constant <1 x i1> splat (i1 true)
+
+; CHECK: @constant.splat.i32 = constant <5 x i32> splat (i32 7)
+@constant.splat.i32 = constant <5 x i32> splat (i32 7)
+
+; CHECK: @constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)
+@constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)
+
+; CHECK: @constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)
+@constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)
+
+; CHECK: @constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)
+@constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)
+
+; CHECK: @constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)
+@constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)
+
+; CHECK: @constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+@constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)
+
+; CHECK: @constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)
+@constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)
+
+; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
+
+; CHECK: @constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+@constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
+
+define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) {
+; CHECK: %add = add <4 x i32> %a, splat (i32 137)
+  %add = add <4 x i32> %a, splat (i32 137)
+  ret void
+}
+
+define <4 x i32> @ret_fixed_lenth_vector_splat_i32() {
+; CHECK: ret <4 x i32> splat (i32 56)
+  ret <4 x i32> splat (i32 56)
+}
+
+define void @add_fixed_lenth_vector_splat_double(<vscale x 2 x double> %a) {
+; CHECK: %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
+  ret void
+}
+
+define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
+; CHECK: ret <vscale x 4 x i32> splat (i32 78)
+  ret <vscale x 4 x i32> splat (i32 78)
+}

>From 27b6edeea45dfb3be2ce4071a6d3e7833d3c7363 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Tue, 13 Feb 2024 14:05:17 +0000
Subject: [PATCH 2/2] Make ElementCount get interfaces private. Reduce repeated
 calls to getType().

---
 llvm/include/llvm/IR/Constants.h | 22 ++++++++++++++--------
 llvm/lib/IR/AsmWriter.cpp        | 18 +++++++++++-------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 39eec1b738fabb..c0ac9a4aa6750c 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -78,6 +78,7 @@ class ConstantData : public Constant {
 /// Class for constant integers.
 class ConstantInt final : public ConstantData {
   friend class Constant;
+  friend class ConstantVector;
 
   APInt Val;
 
@@ -85,6 +86,12 @@ class ConstantInt final : public ConstantData {
 
   void destroyConstantImpl();
 
+  /// Return a ConstantInt with the specified value and an implied Type. The
+  /// type is the vector type whose integer element type corresponds to the bit
+  /// width of the value.
+  static ConstantInt *get(LLVMContext &Context, ElementCount EC,
+                          const APInt &V);
+
 public:
   ConstantInt(const ConstantInt &) = delete;
 
@@ -123,12 +130,6 @@ class ConstantInt final : public ConstantData {
   /// type is the integer type that corresponds to the bit width of the value.
   static ConstantInt *get(LLVMContext &Context, const APInt &V);
 
-  /// Return a ConstantInt with the specified value and an implied Type. The
-  /// type is the vector type whose integer element type corresponds to the bit
-  /// width of the value.
-  static ConstantInt *get(LLVMContext &Context, ElementCount EC,
-                          const APInt &V);
-
   /// Return a ConstantInt constructed from the string strStart with the given
   /// radix.
   static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix);
@@ -265,6 +266,7 @@ class ConstantInt final : public ConstantData {
 ///
 class ConstantFP final : public ConstantData {
   friend class Constant;
+  friend class ConstantVector;
 
   APFloat Val;
 
@@ -272,6 +274,12 @@ class ConstantFP final : public ConstantData {
 
   void destroyConstantImpl();
 
+  /// Return a ConstantFP with the specified value and an implied Type. The
+  /// type is the vector type whose element type has the same floating point
+  /// semantics as the value.
+  static ConstantFP *get(LLVMContext &Context, ElementCount EC,
+                         const APFloat &V);
+
 public:
   ConstantFP(const ConstantFP &) = delete;
 
@@ -287,8 +295,6 @@ class ConstantFP final : public ConstantData {
 
   static Constant *get(Type *Ty, StringRef Str);
   static ConstantFP *get(LLVMContext &Context, const APFloat &V);
-  static ConstantFP *get(LLVMContext &Context, ElementCount EC,
-                         const APFloat &V);
   static Constant *getNaN(Type *Ty, bool Negative = false,
                           uint64_t Payload = 0);
   static Constant *getQNaN(Type *Ty, bool Negative = false,
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 1fcda6c384d96d..00cc14296e9b05 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1502,33 +1502,37 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) {
 static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   AsmWriterContext &WriterCtx) {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
-    if (CI->getType()->isVectorTy()) {
+    Type *Ty = CI->getType();
+
+    if (Ty->isVectorTy()) {
       Out << "splat (";
-      WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
+      WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
       Out << " ";
     }
 
-    if (CI->getType()->getScalarType()->isIntegerTy(1))
+    if (Ty->getScalarType()->isIntegerTy(1))
       Out << (CI->getZExtValue() ? "true" : "false");
     else
       Out << CI->getValue();
 
-    if (CI->getType()->isVectorTy())
+    if (Ty->isVectorTy())
       Out << ")";
 
     return;
   }
 
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
-    if (CFP->getType()->isVectorTy()) {
+    Type *Ty = CFP->getType();
+
+    if (Ty->isVectorTy()) {
       Out << "splat (";
-      WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
+      WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
       Out << " ";
     }
 
     WriteAPFloatInternal(Out, CFP->getValueAPF());
 
-    if (CFP->getType()->isVectorTy())
+    if (Ty->isVectorTy())
       Out << ")";
 
     return;



More information about the llvm-commits mailing list