[Mlir-commits] [llvm] [clang] [mlir] [LLVM][IR] Replace ConstantInt's specialisation of getType() with getIntegerType(). (PR #75217)

Paul Walker llvmlistbot at llvm.org
Fri Dec 15 08:45:18 PST 2023


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/75217

>From d19e9e20432c0dfe50bfba7cd782179653f42b2b Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Nov 2023 14:45:06 +0000
Subject: [PATCH] [LLVM][IR] Replace ConstantInt's specialisation of getType()
 with getIntegerType().

The specialisation is no longer valid because ConstantInt is due to
gain native support for vector types.

Co-authored-by: Nikita Popov <github at npopov.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp                   |  7 ++++---
 llvm/include/llvm/IR/Constants.h                  |  7 +++----
 llvm/lib/Analysis/InstructionSimplify.cpp         |  2 +-
 llvm/lib/IR/ConstantFold.cpp                      |  2 +-
 llvm/lib/IR/Verifier.cpp                          | 11 ++++++-----
 .../Hexagon/HexagonLoopIdiomRecognition.cpp       |  6 +++---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp             | 15 ++++++++-------
 .../InstCombine/InstCombineVectorOps.cpp          |  4 ++--
 llvm/lib/Transforms/Scalar/ConstantHoisting.cpp   |  5 ++---
 llvm/lib/Transforms/Scalar/LoopFlatten.cpp        |  5 ++---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp         |  8 ++++----
 mlir/lib/Target/LLVMIR/ModuleImport.cpp           |  2 +-
 12 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 353b7930b3c1ea..f2a199c3f61d3c 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3214,7 +3214,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(PtrValue, Ptr,
@@ -17027,7 +17027,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(Op1, E->getArg(1),
@@ -17265,7 +17265,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
         Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));
 
     if (getTarget().isLittleEndian())
-      Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
+      Index =
+          ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());
 
     return Builder.CreateExtractElement(Unpacked, Index);
   }
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 0b9f89830b79c6..b5dcc7fbc1d929 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -171,10 +171,9 @@ class ConstantInt final : public ConstantData {
   /// Determine if this constant's value is same as an unsigned char.
   bool equalsInt(uint64_t V) const { return Val == V; }
 
-  /// getType - Specialize the getType() method to always return an IntegerType,
-  /// which reduces the amount of casting needed in parts of the compiler.
-  ///
-  inline IntegerType *getType() const {
+  /// Variant of the getType() method that always returns an IntegerType,
+  /// which reduces the amount of casting needed in parts of the compiler.
+  inline IntegerType *getIntegerType() const {
     return cast<IntegerType>(Value::getType());
   }
 
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2a45acf63aa2ca..5beac5547d65e0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6079,7 +6079,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
 
   auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
-  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+  if (!OffsetConstInt || OffsetConstInt->getBitWidth() > 64)
     return nullptr;
 
   APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index d499d74f7ba010..7fdc35e7fca097 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
           }
 
           if (GVAlign > 1) {
-            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned DstWidth = CI2->getBitWidth();
             unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
             APInt BitsNotSet(APInt::getLowBitsSet(DstWidth, SrcWidth));
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index cdc556ba7df820..404662fd2fd9ef 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2298,10 +2298,9 @@ void Verifier::verifyFunctionMetadata(
       Check(isa<ConstantAsMetadata>(MD->getOperand(0)),
             "expected a constant operand for !kcfi_type", MD);
       Constant *C = cast<ConstantAsMetadata>(MD->getOperand(0))->getValue();
-      Check(isa<ConstantInt>(C),
+      Check(isa<ConstantInt>(C) && isa<IntegerType>(C->getType()),
             "expected a constant integer operand for !kcfi_type", MD);
-      IntegerType *Type = cast<ConstantInt>(C)->getType();
-      Check(Type->getBitWidth() == 32,
+      Check(cast<ConstantInt>(C)->getBitWidth() == 32,
             "expected a 32-bit integer constant operand for !kcfi_type", MD);
     }
   }
@@ -5692,8 +5691,10 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
           "vector of ints");
 
     auto *Op3 = cast<ConstantInt>(Call.getArgOperand(2));
-    Check(Op3->getType()->getBitWidth() <= 32,
-          "third argument of [us][mul|div]_fix[_sat] must fit within 32 bits");
+    Check(Op3->getType()->isIntegerTy(),
+          "third operand of [us][mul|div]_fix[_sat] must be an int type");
+    Check(Op3->getBitWidth() <= 32,
+          "third operand of [us][mul|div]_fix[_sat] must fit within 32 bits");
 
     if (ID == Intrinsic::smul_fix || ID == Intrinsic::smul_fix_sat ||
         ID == Intrinsic::sdiv_fix || ID == Intrinsic::sdiv_fix_sat) {
diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
index 51ef72b873a516..7777ae23e8aecd 100644
--- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
@@ -1062,7 +1062,7 @@ void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
   // Promote immediates.
   for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
-      if (CI->getType()->getBitWidth() < DestBW)
+      if (CI->getBitWidth() < DestBW)
         In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
   }
 }
@@ -1577,7 +1577,7 @@ Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
 
 static bool hasZeroSignBit(const Value *V) {
   if (const auto *CI = dyn_cast<const ConstantInt>(V))
-    return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0;
+    return CI->getValue().isNonNegative();
   const Instruction *I = dyn_cast<const Instruction>(V);
   if (!I)
     return false;
@@ -1688,7 +1688,7 @@ void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
       if (I->getOpcode() != Instruction::Or)
         return nullptr;
       ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
-      if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit())
+      if (!Msb || !Msb->getValue().isSignMask())
         return nullptr;
       if (!hasZeroSignBit(I->getOperand(0)))
         return nullptr;
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index b2665161c090df..2c880316e0a1cd 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3763,7 +3763,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *ExecModeC =
         KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedExecModeC = ConstantInt::get(
-        ExecModeC->getType(),
+        ExecModeC->getIntegerType(),
         ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
@@ -3792,7 +3792,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     ConstantInt *MayUseNestedParallelismC =
         KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
     ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
-        MayUseNestedParallelismC->getType(), NestedParallelism);
+        MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
     setMayUseNestedParallelismOfKernelEnvironment(
         AssumedMayUseNestedParallelismC);
 
@@ -3801,7 +3801,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
               KernelEnvC);
       ConstantInt *AssumedUseGenericStateMachineC =
-          ConstantInt::get(UseGenericStateMachineC->getType(), false);
+          ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
       setUseGenericStateMachineOfKernelEnvironment(
           AssumedUseGenericStateMachineC);
     }
@@ -4280,8 +4280,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // kernel is executed in.
     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
            "Initially non-SPMD kernel has SPMD exec mode!");
-    setExecModeOfKernelEnvironment(ConstantInt::get(
-        ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
+    setExecModeOfKernelEnvironment(
+        ConstantInt::get(ExecModeC->getIntegerType(),
+                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
 
     ++NumOpenMPTargetRegionKernelsSPMD;
 
@@ -4332,7 +4333,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     // If not SPMD mode, indicate we use a custom state machine now.
     setUseGenericStateMachineOfKernelEnvironment(
-        ConstantInt::get(UseStateMachineC->getType(), false));
+        ConstantInt::get(UseStateMachineC->getIntegerType(), false));
 
     // If we don't actually need a state machine we are done here. This can
     // happen if there simply are no parallel regions. In the resulting kernel
@@ -4658,7 +4659,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
             KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
                 AA.KernelEnvC);
         ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
-            MayUseNestedParallelismC->getType(), AA.NestedParallelism);
+            MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
         AA.setMayUseNestedParallelismOfKernelEnvironment(
             NewMayUseNestedParallelismC);
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c8b58c51d4e6ec..9a2c2483a0379c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -388,7 +388,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
 /// arbitrarily pick 64 bit as our canonical type.  The actual bitwidth doesn't
 /// matter, we just want a consistent type to simplify CSE.
 static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) {
-  const unsigned IndexBW = IndexC->getType()->getBitWidth();
+  const unsigned IndexBW = IndexC->getBitWidth();
   if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64)
     return nullptr;
   return ConstantInt::get(IndexC->getContext(),
@@ -2639,7 +2639,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
     assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?");
 
     // Index is updated to the potentially translated insertion lane.
-    IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex);
+    IndexC = ConstantInt::get(IndexC->getIntegerType(), NewInsIndex);
     return true;
   };
 
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 1fb9d7fff32f66..9e40d94dd73c76 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -674,8 +674,7 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) {
   llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS,
                                      const ConstantCandidate &RHS) {
     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
-      return LHS.ConstInt->getType()->getBitWidth() <
-             RHS.ConstInt->getType()->getBitWidth();
+      return LHS.ConstInt->getBitWidth() < RHS.ConstInt->getBitWidth();
     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
   });
 
@@ -890,7 +889,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
         Type *Ty = ConstInfo.BaseExpr->getType();
         Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP);
       } else {
-        IntegerType *Ty = ConstInfo.BaseInt->getType();
+        IntegerType *Ty = ConstInfo.BaseInt->getIntegerType();
         Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP);
       }
 
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index b1add3c42976fd..eef94636578d83 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -343,9 +343,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
     // If the RHS of the compare is equal to the backedge taken count we need
     // to add one to get the trip count.
     if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
-      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
-      Value *NewRHS = ConstantInt::get(
-          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
+      Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),
+                                       ConstantRHS->getValue() + 1);
       return setLoopComponents(NewRHS, TripCount, Increment,
                                IterationInstructions);
     }
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 89494a7f64971f..55e375670cc61e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -6293,7 +6293,7 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) {
   }
   case BitMapKind: {
     // Type of the bitmap (e.g. i59).
-    IntegerType *MapTy = BitMap->getType();
+    IntegerType *MapTy = BitMap->getIntegerType();
 
     // Cast Index to the same type as the bitmap.
     // Note: The Index is <= the number of elements in the table, so
@@ -6668,7 +6668,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
   Value *TableIndex;
   ConstantInt *TableIndexOffset;
   if (UseSwitchConditionAsTableIndex) {
-    TableIndexOffset = ConstantInt::get(MaxCaseVal->getType(), 0);
+    TableIndexOffset = ConstantInt::get(MaxCaseVal->getIntegerType(), 0);
     TableIndex = SI->getCondition();
   } else {
     TableIndexOffset = MinCaseVal;
@@ -6752,7 +6752,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
     // Get the TableIndex'th bit of the bitmask.
     // If this bit is 0 (meaning hole) jump to the default destination,
     // else continue with table lookup.
-    IntegerType *MapTy = TableMask->getType();
+    IntegerType *MapTy = TableMask->getIntegerType();
     Value *MaskIndex =
         Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex");
     Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted");
@@ -6975,7 +6975,7 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
   // Replace each case with its trailing zeros number.
   for (auto &Case : SI->cases()) {
     auto *OrigValue = Case.getCaseValue();
-    Case.setValue(ConstantInt::get(OrigValue->getType(),
+    Case.setValue(ConstantInt::get(OrigValue->getIntegerType(),
                                    OrigValue->getValue().countr_zero()));
   }
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ec2692f58695d0..905405e9398820 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -720,7 +720,7 @@ static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
   // Convert scalar intergers.
   if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
     return builder.getIntegerAttr(
-        IntegerType::get(context, constInt->getType()->getBitWidth()),
+        IntegerType::get(context, constInt->getBitWidth()),
         constInt->getValue());
   }
 



More information about the Mlir-commits mailing list