[llvm] d37fe26 - [NFC][IR] Type: add getWithNewType() method

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 22 14:52:22 PDT 2021


Author: Roman Lebedev
Date: 2021-03-23T00:50:58+03:00
New Revision: d37fe26a2bbf2f59e3fe3385ef7127086c9d7ee1

URL: https://github.com/llvm/llvm-project/commit/d37fe26a2bbf2f59e3fe3385ef7127086c9d7ee1
DIFF: https://github.com/llvm/llvm-project/commit/d37fe26a2bbf2f59e3fe3385ef7127086c9d7ee1.diff

LOG: [NFC][IR] Type: add getWithNewType() method

Sometimes you want to get a type with the same vector element count
as the current type but a different element type,
but there's no quality-of-life (QOL) wrapper to do that. Add one.

Added: 
    

Modified: 
    llvm/include/llvm/IR/DerivedTypes.h
    llvm/include/llvm/IR/Type.h
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 41814d5b50fe..edc59a8be55a 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -677,14 +677,17 @@ Type *Type::getExtendedType() const {
   return cast<IntegerType>(this)->getExtendedType();
 }
 
+Type *Type::getWithNewType(Type *EltTy) const {
+  if (auto *VTy = dyn_cast<VectorType>(this))
+    return VectorType::get(EltTy, VTy->getElementCount());
+  return EltTy;
+}
+
 Type *Type::getWithNewBitWidth(unsigned NewBitWidth) const {
   assert(
       isIntOrIntVectorTy() &&
       "Original type expected to be a vector of integers or a scalar integer.");
-  Type *NewType = getIntNTy(getContext(), NewBitWidth);
-  if (auto *VTy = dyn_cast<VectorType>(this))
-    NewType = VectorType::get(NewType, VTy->getElementCount());
-  return NewType;
+  return getWithNewType(getIntNTy(getContext(), NewBitWidth));
 }
 
 unsigned Type::getPointerAddressSpace() const {

diff  --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 2b1f054b1118..e30d4ad22613 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -380,6 +380,11 @@ class Type {
     return ContainedTys[0];
   }
 
+  /// Given vector type, change the element type,
+  /// whilst keeping the old number of elements.
+  /// For non-vectors simply returns \p EltTy.
+  inline Type *getWithNewType(Type *EltTy) const;
+
   /// Given an integer or vector type, change the lane bitwidth to NewBitwidth,
   /// whilst keeping the old number of lanes.
   inline Type *getWithNewBitWidth(unsigned NewBitWidth) const;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index ae72123e3f00..75621da20a5d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1927,11 +1927,8 @@ Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
   unsigned AS = CI.getAddressSpace();
   if (CI.getOperand(0)->getType()->getScalarSizeInBits() !=
       DL.getPointerSizeInBits(AS)) {
-    Type *Ty = DL.getIntPtrType(CI.getContext(), AS);
-    // Handle vectors of pointers.
-    if (auto *CIVTy = dyn_cast<VectorType>(CI.getType()))
-      Ty = VectorType::get(Ty, CIVTy->getElementCount());
-
+    Type *Ty = CI.getOperand(0)->getType()->getWithNewType(
+        DL.getIntPtrType(CI.getContext(), AS));
     Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty);
     return new IntToPtrInst(P, CI.getType());
   }
@@ -1970,16 +1967,14 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
   // do a ptrtoint to intptr_t then do a trunc or zext.  This allows the cast
   // to be exposed to other transforms.
   Value *SrcOp = CI.getPointerOperand();
+  Type *SrcTy = SrcOp->getType();
   Type *Ty = CI.getType();
   unsigned AS = CI.getPointerAddressSpace();
   unsigned TySize = Ty->getScalarSizeInBits();
   unsigned PtrSize = DL.getPointerSizeInBits(AS);
   if (TySize != PtrSize) {
-    Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS);
-    // Handle vectors of pointers.
-    if (auto *VecTy = dyn_cast<VectorType>(Ty))
-      IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount());
-
+    Type *IntPtrTy =
+        SrcTy->getWithNewType(DL.getIntPtrType(CI.getContext(), AS));
     Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy);
     return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
   }


        


More information about the llvm-commits mailing list