[clang] [llvm] [LLVM][IR] Add native vector support to ConstantInt & ConstantFP. (PR #74502)

via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 5 09:42:35 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.

NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors
    
ConstantExprs are being deprecated and ConstantVector is not space
efficient for larger vector types. This patch introduces an
alternative by allowing ConstantInt to natively support vector
splats via the IR syntax:

  <N x ty> splat(ty <imm>)

More specifically:

 * IR parsing is extended to support the new syntax.
 * ConstantInt gains the interface getSplat().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitCodeReader/Writer is extended to support vector types.
    
Whilst this patch adds the base support, more work is required
before it's production ready. For example, there's likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly,
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags is added to allow wider testing and thus help with the
migration:
    
  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: Code generation doesn't work out of the box, but the issues look
limited to calls to ConstantInt::getBitWidth() that will need to be
ported.

---

Patch is 41.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74502.diff


22 Files Affected:

- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4-3) 
- (modified) llvm/include/llvm/AsmParser/LLParser.h (+3-1) 
- (modified) llvm/include/llvm/AsmParser/LLToken.h (+1) 
- (modified) llvm/include/llvm/IR/Constants.h (+23-5) 
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1-1) 
- (modified) llvm/lib/AsmParser/LLLexer.cpp (+1) 
- (modified) llvm/lib/AsmParser/LLParser.cpp (+50-6) 
- (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+41-29) 
- (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+3-3) 
- (modified) llvm/lib/IR/AsmWriter.cpp (+32-5) 
- (modified) llvm/lib/IR/ConstantFold.cpp (+1-1) 
- (modified) llvm/lib/IR/Constants.cpp (+92-1) 
- (modified) llvm/lib/IR/LLVMContextImpl.cpp (+2) 
- (modified) llvm/lib/IR/LLVMContextImpl.h (+4) 
- (modified) llvm/lib/IR/Verifier.cpp (+2-2) 
- (modified) llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp (+3-3) 
- (modified) llvm/lib/Transforms/IPO/OpenMPOpt.cpp (+8-7) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp (+2-2) 
- (modified) llvm/lib/Transforms/Scalar/ConstantHoisting.cpp (+3-3) 
- (modified) llvm/lib/Transforms/Scalar/LoopFlatten.cpp (+1-1) 
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+4-4) 
- (added) llvm/test/Bitcode/constant-splat.ll (+53) 


``````````diff
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d9862621061..8dc828abf8aec 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3218,7 +3218,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(PtrValue, Ptr,
@@ -17010,7 +17010,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(Op1, E->getArg(1),
@@ -17248,7 +17248,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
         Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));
 
     if (getTarget().isLittleEndian())
-      Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
+      Index =
+          ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());
 
     return Builder.CreateExtractElement(Unpacked, Index);
   }
diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index 810f3668d05d4..38f6f08b8f3a1 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -59,7 +59,9 @@ namespace llvm {
       t_Constant,                      // Value in ConstantVal.
       t_InlineAsm,                     // Value in FTy/StrVal/StrVal2/UIntVal.
       t_ConstantStruct,                // Value in ConstantStructElts.
-      t_PackedConstantStruct           // Value in ConstantStructElts.
+      t_PackedConstantStruct,          // Value in ConstantStructElts.
+      t_APSIntSplat,                   // Value in APSIntVal.
+      t_APFloatSplat                   // Value in APFloatVal.
     } Kind = t_LocalID;
 
     LLLexer::LocTy Loc;
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0683291faae72..dd55afee21033 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -335,6 +335,7 @@ enum Kind {
   kw_extractelement,
   kw_insertelement,
   kw_shufflevector,
+  kw_splat,
   kw_extractvalue,
   kw_insertvalue,
   kw_blockaddress,
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 2f7fc5652c2cd..b76cb1beecf3c 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -81,6 +81,7 @@ class ConstantInt final : public ConstantData {
   APInt Val;
 
   ConstantInt(IntegerType *Ty, const APInt &V);
+  ConstantInt(VectorType *Ty, const APInt &V);
 
   void destroyConstantImpl();
 
@@ -98,6 +99,13 @@ class ConstantInt final : public ConstantData {
   /// value. Otherwise return a ConstantInt for the given value.
   static Constant *get(Type *Ty, uint64_t V, bool IsSigned = false);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantInt::get() should be used.
+  /// Return a ConstantInt with a splat of the given value.
+  static ConstantInt *getSplat(LLVMContext &Context, ElementCount EC,
+                               const APInt &V);
+  static ConstantInt *getSplat(const VectorType *Ty, const APInt &V);
+
   /// Return a ConstantInt with the specified integer value for the specified
   /// type. If the type is wider than 64 bits, the value will be zero-extended
   /// to fit the type, unless IsSigned is true, in which case the value will
@@ -136,7 +144,11 @@ class ConstantInt final : public ConstantData {
   inline const APInt &getValue() const { return Val; }
 
   /// getBitWidth - Return the bitwidth of this constant.
-  unsigned getBitWidth() const { return Val.getBitWidth(); }
+  unsigned getBitWidth() const {
+    assert(getType()->isIntegerTy() &&
+           "Returning the bitwidth of a vector constant is not support!");
+    return Val.getBitWidth();
+  }
 
   /// Return the constant as a 64-bit unsigned integer value after it
   /// has been zero extended as appropriate for the type of this constant. Note
@@ -170,10 +182,9 @@ class ConstantInt final : public ConstantData {
   /// Determine if this constant's value is same as an unsigned char.
   bool equalsInt(uint64_t V) const { return Val == V; }
 
-  /// getType - Specialize the getType() method to always return an IntegerType,
-  /// which reduces the amount of casting needed in parts of the compiler.
-  ///
-  inline IntegerType *getType() const {
+  /// Variant of the getType() method to always return an IntegerType, which
+  /// reduces the amount of casting needed in parts of the compiler.
+  inline IntegerType *getIntegerType() const {
     return cast<IntegerType>(Value::getType());
   }
 
@@ -279,6 +290,13 @@ class ConstantFP final : public ConstantData {
   /// value. Otherwise return a ConstantFP for the given value.
   static Constant *get(Type *Ty, const APFloat &V);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantFP::get() should be used.
+  /// Return a ConstantFP with a splat of the given value.
+  static ConstantFP *getSplat(LLVMContext &Context, ElementCount EC,
+                              const APFloat &V);
+  static ConstantFP *getSplat(const VectorType *Ty, const APFloat &V);
+
   static Constant *get(Type *Ty, StringRef Str);
   static ConstantFP *get(LLVMContext &Context, const APFloat &V);
   static Constant *getNaN(Type *Ty, bool Negative = false,
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..c24bb1bb2cf9f 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6081,7 +6081,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
 
   auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
-  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+  if (!OffsetConstInt || OffsetConstInt->getIntegerType()->getBitWidth() > 64)
     return nullptr;
 
   APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 09a205c445dbe..eb47284feb218 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -697,6 +697,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(uinc_wrap);
   KEYWORD(udec_wrap);
 
+  KEYWORD(splat);
   KEYWORD(vscale);
   KEYWORD(x);
   KEYWORD(blockaddress);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index d236b6cfa9000..94e1a51aa2e75 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3952,6 +3952,31 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
 
+  case lltok::kw_splat: {
+    Lex.Lex();
+    if (parseToken(lltok::lparen, "expected '(' after vector splat"))
+      return true;
+    Constant *C;
+    if (parseGlobalTypeAndValue(C))
+      return true;
+    if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
+      return true;
+
+    if (auto *CI = dyn_cast<ConstantInt>(C)) {
+      ID.APSIntVal = CI->getValue();
+      ID.Kind = ValID::t_APSIntSplat;
+      return false;
+    }
+
+    if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+      ID.APFloatVal = CFP->getValue();
+      ID.Kind = ValID::t_APFloatSplat;
+      return false;
+    }
+
+    return tokError("invalid splat operand");
+  }
+
   case lltok::kw_getelementptr:
   case lltok::kw_shufflevector:
   case lltok::kw_insertelement:
@@ -5716,9 +5741,23 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
     ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getPrimitiveSizeInBits());
     V = ConstantInt::get(Context, ID.APSIntVal);
     return false;
+  case ValID::t_APSIntSplat:
+    if (!Ty->isVectorTy() || !Ty->getScalarType()->isIntegerTy())
+      return error(ID.Loc, "expected an integer vector result");
+    if (ID.APSIntVal.getBitWidth() !=
+        cast<IntegerType>(Ty->getScalarType())->getBitWidth())
+      return error(ID.Loc, "operand type must match result element type");
+    V = ConstantInt::getSplat(cast<VectorType>(Ty), ID.APSIntVal);
+    return false;
   case ValID::t_APFloat:
-    if (!Ty->isFloatingPointTy() ||
-        !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
+  case ValID::t_APFloatSplat: {
+    if ((ID.Kind == ValID::t_APFloat && !Ty->isFloatingPointTy()) ||
+        (ID.Kind == ValID::t_APFloatSplat && !Ty->isVectorTy()))
+      return error(ID.Loc, "floating point constant invalid for type");
+
+    Type *ScalarTy = Ty->getScalarType();
+    if (!ScalarTy->isFloatingPointTy() ||
+        !ConstantFP::isValueValidForType(ScalarTy, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
     // The lexer has no type info, so builds all half, bfloat, float, and double
@@ -5727,13 +5766,13 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (ScalarTy->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isBFloatTy())
+      else if (ScalarTy->isBFloatTy())
         ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isFloatTy())
+      else if (ScalarTy->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       if (IsSNAN) {
@@ -5745,13 +5784,18 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
                                          ID.APFloatVal.isNegative(), &Payload);
       }
     }
-    V = ConstantFP::get(Context, ID.APFloatVal);
+
+    if (auto *VTy = dyn_cast<VectorType>(Ty))
+      V = ConstantFP::getSplat(VTy, ID.APFloatVal);
+    else
+      V = ConstantFP::get(Context, ID.APFloatVal);
 
     if (V->getType() != Ty)
       return error(ID.Loc, "floating point constant does not have type '" +
                                getTypeString(Ty) + "'");
 
     return false;
+  }
   case ValID::t_Null:
     if (!Ty->isPointerTy())
       return error(ID.Loc, "null must be a pointer type");
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index e4c3770946b3a..b661d36fb6854 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3022,50 +3022,62 @@ Error BitcodeReader::parseConstants() {
       V = Constant::getNullValue(CurTy);
       break;
     case bitc::CST_CODE_INTEGER:   // INTEGER: [intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid integer const record");
-      V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy)) {
+        auto *ScalarTy = cast<IntegerType>(VTy->getScalarType());
+        unsigned BitWidth = ScalarTy->getBitWidth();
+        APInt VInt(BitWidth, decodeSignRotatedValue(Record[0]));
+        V = ConstantInt::getSplat(VTy, VInt);
+      } else
+        V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
       break;
     case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid wide integer const record");
 
-      APInt VInt =
-          readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
-      V = ConstantInt::get(Context, VInt);
-
+      auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
+      APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantInt::getSplat(VTy, VInt);
+      else
+        V = ConstantInt::get(Context, VInt);
       break;
     }
     case bitc::CST_CODE_FLOAT: {    // FLOAT: [fpval]
       if (Record.empty())
         return error("Invalid float const record");
-      if (CurTy->isHalfTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
-                                             APInt(16, (uint16_t)Record[0])));
-      else if (CurTy->isBFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
-                                             APInt(16, (uint32_t)Record[0])));
-      else if (CurTy->isFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
-                                             APInt(32, (uint32_t)Record[0])));
-      else if (CurTy->isDoubleTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
-                                             APInt(64, Record[0])));
-      else if (CurTy->isX86_FP80Ty()) {
+
+      APFloat Val(APFloat::Bogus());
+      auto *ScalarTy = CurTy->getScalarType();
+      if (ScalarTy->isHalfTy())
+        Val = APFloat(APFloat::IEEEhalf(), APInt(16, (uint16_t)Record[0]));
+      else if (ScalarTy->isBFloatTy())
+        Val = APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0]));
+      else if (ScalarTy->isFloatTy())
+        Val = APFloat(APFloat::IEEEsingle(), APInt(32, (uint32_t)Record[0]));
+      else if (ScalarTy->isDoubleTy())
+        Val = APFloat(APFloat::IEEEdouble(), APInt(64, Record[0]));
+      else if (ScalarTy->isX86_FP80Ty()) {
         // Bits are not stored the same way as a normal i80 APInt, compensate.
         uint64_t Rearrange[2];
         Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
         Rearrange[1] = Record[0] >> 48;
-        V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
-                                             APInt(80, Rearrange)));
-      } else if (CurTy->isFP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
-                                             APInt(128, Record)));
-      else if (CurTy->isPPC_FP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
-                                             APInt(128, Record)));
-      else
+        Val = APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange));
+      } else if (ScalarTy->isFP128Ty())
+        Val = APFloat(APFloat::IEEEquad(), APInt(128, Record));
+      else if (ScalarTy->isPPC_FP128Ty())
+        Val = APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record));
+      else {
         V = UndefValue::get(CurTy);
+        break;
+      }
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantFP::getSplat(VTy, Val);
+      else
+        V = ConstantFP::get(Context, Val);
       break;
     }
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8239775d04865..0f5b9ff9ebd72 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2577,18 +2577,18 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (isa<UndefValue>(C)) {
       Code = bitc::CST_CODE_UNDEF;
     } else if (const ConstantInt *IV = dyn_cast<ConstantInt>(C)) {
-      if (IV->getBitWidth() <= 64) {
+      if (IV->getValue().getBitWidth() <= 64) {
         uint64_t V = IV->getSExtValue();
         emitSignedInt64(Record, V);
         Code = bitc::CST_CODE_INTEGER;
         AbbrevToUse = CONSTANTS_INTEGER_ABBREV;
-      } else {                             // Wide integers, > 64 bits in size.
+      } else { // Wide integers, > 64 bits in size.
         emitWideAPInt(Record, IV->getValue());
         Code = bitc::CST_CODE_WIDE_INTEGER;
       }
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
-      Type *Ty = CFP->getType();
+      Type *Ty = CFP->getType()->getScalarType();
       if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
           Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index fabc79adbd33d..e37da7c460f26 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1394,16 +1394,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
 static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   AsmWriterContext &WriterCtx) {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
-    if (CI->getType()->isIntegerTy(1)) {
-      Out << (CI->getZExtValue() ? "true" : "false");
-      return;
+    if (CI->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
+      Out << " ";
     }
-    Out << CI->getValue();
+
+    if (CI->getType()->getScalarType()->isIntegerTy(1))
+      Out << (CI->getZExtValue() ? "true" : "false");
+    else
+      Out << CI->getValue();
+
+    if (CI->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
     const APFloat &APF = CFP->getValueAPF();
+
+    if (CFP->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
+      Out << " ";
+    }
+
     if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
         &APF.getSemantics() == &APFloat::IEEEdouble()) {
       // We would like to output the FP constant value in exponential notation,
@@ -1429,6 +1445,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         // Reparse stringized version!
         if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
           Out << StrVal;
+
+          if (CFP->getType()->isVectorTy())
+            Out << ")";
+
           return;
         }
       }
@@ -1454,6 +1474,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         }
       }
       Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);
+
+      if (CFP->getType()->isVectorTy())
+        Out << ")";
+
       return;
     }
 
@@ -1468,7 +1492,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
                                   /*Upper=*/true);
-      return;
     } else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
       Out << 'L';
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
@@ -1491,6 +1514,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
+
+    if (CFP->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index d499d74f7ba01..c478040234078 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
           }
 
           if (GVAlign > 1) {
-            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned DstWidth = CI2->getIntegerType()->getBitWidth();
             unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
             APInt BitsNotSet(APInt::...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74502


More information about the cfe-commits mailing list