[llvm] dea16eb - [LLVM][IR] Replace ConstantInt's specialisation of getType() with getIntegerType(). (#75217)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 18 03:58:47 PST 2023


Author: Paul Walker
Date: 2023-12-18T11:58:42Z
New Revision: dea16ebd2613a4a218c53045270fc4fcc9b427ad

URL: https://github.com/llvm/llvm-project/commit/dea16ebd2613a4a218c53045270fc4fcc9b427ad
DIFF: https://github.com/llvm/llvm-project/commit/dea16ebd2613a4a218c53045270fc4fcc9b427ad.diff

LOG: [LLVM][IR] Replace ConstantInt's specialisation of getType() with getIntegerType(). (#75217)

The specialisation will not be valid when ConstantInt gains native
support for vector types.

This is largely a mechanical change but with extra attention paid to constant
folding, InstCombineVectorOps.cpp, LoopFlatten.cpp and Verifier.cpp to
remove the need to call `getIntegerType()`.

Co-authored-by: Nikita Popov <github at npopov.com>

Added: 
    

Modified: 
    clang/lib/CodeGen/CGBuiltin.cpp
    llvm/include/llvm/IR/Constants.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/IR/ConstantFold.cpp
    llvm/lib/IR/Verifier.cpp
    llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp
    llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
    llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp
    llvm/lib/Transforms/Utils/SimplifyCFG.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index c42b885289c4c7..4eb1686f095062 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3214,7 +3214,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(PtrValue, Ptr,
@@ -17034,7 +17034,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(Op1, E->getArg(1),
@@ -17272,7 +17272,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
         Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));
 
     if (getTarget().isLittleEndian())
-      Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
+      Index =
+          ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());
 
     return Builder.CreateExtractElement(Unpacked, Index);
   }

diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 0b9f89830b79c6..b5dcc7fbc1d929 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -171,10 +171,9 @@ class ConstantInt final : public ConstantData {
   /// Determine if this constant's value is same as an unsigned char.
   bool equalsInt(uint64_t V) const { return Val == V; }
 
-  /// getType - Specialize the getType() method to always return an IntegerType,
-  /// which reduces the amount of casting needed in parts of the compiler.
-  ///
-  inline IntegerType *getType() const {
+  /// Variant of the getType() method to always return an IntegerType, which
+  /// reduces the amount of casting needed in parts of the compiler.
+  inline IntegerType *getIntegerType() const {
     return cast<IntegerType>(Value::getType());
   }
 

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2a45acf63aa2ca..5beac5547d65e0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6079,7 +6079,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
 
   auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
-  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+  if (!OffsetConstInt || OffsetConstInt->getBitWidth() > 64)
     return nullptr;
 
   APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(

diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index d499d74f7ba010..7fdc35e7fca097 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
           }
 
           if (GVAlign > 1) {
-            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned DstWidth = CI2->getBitWidth();
             unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
             APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));
 

diff  --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 8aba28026306a5..aeaca21a99cc5e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2296,10 +2296,9 @@ void Verifier::verifyFunctionMetadata(
       Check(isa<ConstantAsMetadata>(MD->getOperand(0)),
             "expected a constant operand for !kcfi_type", MD);
       Constant *C = cast<ConstantAsMetadata>(MD->getOperand(0))->getValue();
-      Check(isa<ConstantInt>(C),
+      Check(isa<ConstantInt>(C) && isa<IntegerType>(C->getType()),
             "expected a constant integer operand for !kcfi_type", MD);
-      IntegerType *Type = cast<ConstantInt>(C)->getType();
-      Check(Type->getBitWidth() == 32,
+      Check(cast<ConstantInt>(C)->getBitWidth() == 32,
             "expected a 32-bit integer constant operand for !kcfi_type", MD);
     }
   }
@@ -5690,8 +5689,10 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           "vector of ints");
 
     auto *Op3 = cast<ConstantInt>(Call.getArgOperand(2));
-    Check(Op3->getType()->getBitWidth() <= 32,
-          "third argument of [us][mul|div]_fix[_sat] must fit within 32 bits");
+    Check(Op3->getType()->isIntegerTy(),
+          "third operand of [us][mul|div]_fix[_sat] must be an int type");
+    Check(Op3->getBitWidth() <= 32,
+          "third operand of [us][mul|div]_fix[_sat] must fit within 32 bits");
 
     if (ID == Intrinsic::smul_fix || ID == Intrinsic::smul_fix_sat ||
         ID == Intrinsic::sdiv_fix || ID == Intrinsic::sdiv_fix_sat) {

diff  --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
index 51ef72b873a516..7777ae23e8aecd 100644
--- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
@@ -1062,7 +1062,7 @@ void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
   // Promote immediates.
   for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
-      if (CI->getType()->getBitWidth() < DestBW)
+      if (CI->getBitWidth() < DestBW)
         In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
   }
 }
@@ -1577,7 +1577,7 @@ Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
 
 static bool hasZeroSignBit(const Value *V) {
   if (const auto *CI = dyn_cast<const ConstantInt>(V))
-    return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0;
+    return CI->getValue().isNonNegative();
   const Instruction *I = dyn_cast<const Instruction>(V);
   if (!I)
     return false;
@@ -1688,7 +1688,7 @@ void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
       if (I->getOpcode() != Instruction::Or)
         return nullptr;
       ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
-      if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit())
+      if (!Msb || !Msb->getValue().isSignMask())
         return nullptr;
       if (!hasZeroSignBit(I->getOperand(0)))
         return nullptr;

diff  --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index b2665161c090df..2c880316e0a1cd 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3763,7 +3763,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *ExecModeC =
         KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedExecModeC = ConstantInt::get(
-        ExecModeC->getType(),
+        ExecModeC->getIntegerType(),
         ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
@@ -3792,7 +3792,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *MayUseNestedParallelismC =
         KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
-        MayUseNestedParallelismC->getType(), NestedParallelism);
+        MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
     setMayUseNestedParallelismOfKernelEnvironment(
         AssumedMayUseNestedParallelismC);
 
@@ -3801,7 +3801,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
               KernelEnvC);
       ConstantInt *AssumedUseGenericStateMachineC =
-          ConstantInt::get(UseGenericStateMachineC->getType(), false);
+          ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
       setUseGenericStateMachineOfKernelEnvironment(
           AssumedUseGenericStateMachineC);
     }
@@ -4280,8 +4280,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // kernel is executed in.
     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
            "Initially non-SPMD kernel has SPMD exec mode!");
-    setExecModeOfKernelEnvironment(ConstantInt::get(
-        ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
+    setExecModeOfKernelEnvironment(
+        ConstantInt::get(ExecModeC->getIntegerType(),
+                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
 
     ++NumOpenMPTargetRegionKernelsSPMD;
 
@@ -4332,7 +4333,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     // If not SPMD mode, indicate we use a custom state machine now.
     setUseGenericStateMachineOfKernelEnvironment(
-        ConstantInt::get(UseStateMachineC->getType(), false));
+        ConstantInt::get(UseStateMachineC->getIntegerType(), false));
 
     // If we don't actually need a state machine we are done here. This can
     // happen if there simply are no parallel regions. In the resulting kernel
@@ -4658,7 +4659,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
             KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
                 AA.KernelEnvC);
         ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
-            MayUseNestedParallelismC->getType(), AA.NestedParallelism);
+            MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
         AA.setMayUseNestedParallelismOfKernelEnvironment(
             NewMayUseNestedParallelismC);
       }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c381d25011f688..bd5f608045cf11 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -388,7 +388,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
 /// arbitrarily pick 64 bit as our canonical type.  The actual bitwidth doesn't
 /// matter, we just want a consistent type to simplify CSE.
 static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
-  const unsigned IndexBW = IndexC->getType()->getBitWidth();
+  const unsigned IndexBW = IndexC->getBitWidth();
   if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
     return nullptr;
   return ConstantInt::get(IndexC->getContext(),
@@ -2640,7 +2640,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
     assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");
 
     // Index is updated to the potentially translated insertion lane.
-    IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex);
+    IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
     return true;
   };
 

diff  --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 1fb9d7fff32f66..9e40d94dd73c76 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -674,8 +674,7 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) {
   llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS,
                                      const ConstantCandidate &RHS) {
     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
-      return LHS.ConstInt->getType()->getBitWidth() <
-             RHS.ConstInt->getType()->getBitWidth();
+      return LHS.ConstInt->getBitWidth() < RHS.ConstInt->getBitWidth();
     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
   });
 
@@ -890,7 +889,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
         Type *Ty = ConstInfo.BaseExpr->getType();
         Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP);
       } else {
-        IntegerType *Ty = ConstInfo.BaseInt->getType();
+        IntegerType *Ty = ConstInfo.BaseInt->getIntegerType();
         Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP);
       }
 

diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index b1add3c42976fd..eef94636578d83 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -343,9 +343,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
     // If the RHS of the compare is equal to the backedge taken count we need
     // to add one to get the trip count.
     if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
-      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
-      Value *NewRHS = ConstantInt::get(
-          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
+      Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
+                                       ConstantRHS->getValue() + 1);
       return setLoopComponents(NewRHS, TripCount, Increment,
                                IterationInstructions);
     }

diff  --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 89494a7f64971f..55e375670cc61e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6293,7 +6293,7 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
   }
   case BitMapKind: {
     // Type of the bitmap (e.g. i59).
-    IntegerType *MapTy = BitMap->getType();
+    IntegerType *MapTy = BitMap->getIntegerType();
 
     // Cast Index to the same type as the bitmap.
     // Note: The Index is <= the number of elements in the table, so
@@ -6668,7 +6668,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
   Value *TableIndex;
   ConstantInt *TableIndexOffset;
   if (UseSwitchConditionAsTableIndex) {
-    TableIndexOffset = ConstantInt::get(MaxCaseVal->getType(), 0);
+    TableIndexOffset = ConstantInt::get(MaxCaseVal->getIntegerType(), 0);
     TableIndex = SI->getCondition();
   } else {
     TableIndexOffset = MinCaseVal;
@@ -6752,7 +6752,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
     // Get the TableIndex'th bit of the bitmask.
     // If this bit is 0 (meaning hole) jump to the default destination,
     // else continue with table lookup.
-    IntegerType *MapTy = TableMask->getType();
+    IntegerType *MapTy = TableMask->getIntegerType();
     Value *MaskIndex =
         Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex");
     Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted");
@@ -6975,7 +6975,7 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
   // Replace each case with its trailing zeros number.
   for (auto &Case : SI->cases()) {
     auto *OrigValue = Case.getCaseValue();
-    Case.setValue(ConstantInt::get(OrigValue->getType(),
+    Case.setValue(ConstantInt::get(OrigValue->getIntegerType(),
                                    OrigValue->getValue().countr_zero()));
   }
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ec2692f58695d0..905405e9398820 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -720,7 +720,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
   // Convert scalar intergers.
   if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
     return builder.getIntegerAttr(
-        IntegerType::get(context, constInt->getType()->getBitWidth()),
+        IntegerType::get(context, constInt->getBitWidth()),
         constInt->getValue());
   }
 


        


More information about the llvm-commits mailing list