[llvm] 0138399 - [InstCombine] Remove scalable vector restriction in InstCombineCasts

Jun Ma via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 17 06:03:24 PST 2020


Author: Jun Ma
Date: 2020-12-17T22:02:33+08:00
New Revision: 01383999037760288f617e24084991eaf6bd9272

URL: https://github.com/llvm/llvm-project/commit/01383999037760288f617e24084991eaf6bd9272
DIFF: https://github.com/llvm/llvm-project/commit/01383999037760288f617e24084991eaf6bd9272.diff

LOG: [InstCombine] Remove scalable vector restriction in InstCombineCasts

Differential Revision: https://reviews.llvm.org/D93389

Added: 
    

Modified: 
    llvm/lib/IR/Verifier.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/addrspacecast.ll
    llvm/test/Transforms/InstCombine/ptr-int-cast.ll
    llvm/test/Transforms/InstCombine/trunc-extractelement.ll
    llvm/test/Transforms/InstCombine/vec_shuffle.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 71bb94e77593..cac5df81661c 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2961,8 +2961,8 @@ void Verifier::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
   Assert(SrcTy->getPointerAddressSpace() != DestTy->getPointerAddressSpace(),
          "AddrSpaceCast must be between different address spaces", &I);
   if (auto *SrcVTy = dyn_cast<VectorType>(SrcTy))
-    Assert(cast<FixedVectorType>(SrcVTy)->getNumElements() ==
-               cast<FixedVectorType>(DestTy)->getNumElements(),
+    Assert(SrcVTy->getElementCount() ==
+               cast<VectorType>(DestTy)->getElementCount(),
            "AddrSpaceCast vector pointer number of elements mismatch", &I);
   visitInstruction(I);
 }

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e9ec8021e466..8750e83623e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -907,20 +907,21 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
   Value *VecOp;
   ConstantInt *Cst;
   if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
-    auto *VecOpTy = cast<FixedVectorType>(VecOp->getType());
-    unsigned VecNumElts = VecOpTy->getNumElements();
+    auto *VecOpTy = cast<VectorType>(VecOp->getType());
+    auto VecElts = VecOpTy->getElementCount();
 
     // A badly fit destination size would result in an invalid cast.
     if (SrcWidth % DestWidth == 0) {
       uint64_t TruncRatio = SrcWidth / DestWidth;
-      uint64_t BitCastNumElts = VecNumElts * TruncRatio;
+      uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
       uint64_t VecOpIdx = Cst->getZExtValue();
       uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
                                          : VecOpIdx * TruncRatio;
       assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
              "overflow 32-bits");
 
-      auto *BitCastTo = FixedVectorType::get(DestTy, BitCastNumElts);
+      auto *BitCastTo =
+          VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
       Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
       return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
     }
@@ -1974,12 +1975,9 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
   unsigned PtrSize = DL.getPointerSizeInBits(AS);
   if (TySize != PtrSize) {
     Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS);
-    if (auto *VecTy = dyn_cast<VectorType>(Ty)) {
-      // Handle vectors of pointers.
-      // FIXME: what should happen for scalable vectors?
-      IntPtrTy = FixedVectorType::get(
-          IntPtrTy, cast<FixedVectorType>(VecTy)->getNumElements());
-    }
+    // Handle vectors of pointers.
+    if (auto *VecTy = dyn_cast<VectorType>(Ty))
+      IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount());
 
     Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy);
     return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
@@ -2660,13 +2658,11 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
     // a bitcast to a vector with the same # elts.
     Value *ShufOp0 = Shuf->getOperand(0);
     Value *ShufOp1 = Shuf->getOperand(1);
-    unsigned NumShufElts =
-        cast<FixedVectorType>(Shuf->getType())->getNumElements();
-    unsigned NumSrcVecElts =
-        cast<FixedVectorType>(ShufOp0->getType())->getNumElements();
+    auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount();
+    auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount();
     if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
-        cast<FixedVectorType>(DestTy)->getNumElements() == NumShufElts &&
-        NumShufElts == NumSrcVecElts) {
+        cast<VectorType>(DestTy)->getElementCount() == ShufElts &&
+        ShufElts == SrcVecElts) {
       BitCastInst *Tmp;
       // If either of the operands is a cast from CI.getType(), then
       // evaluating the shuffle in the casted destination's type will allow
@@ -2689,8 +2685,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
     // TODO: We should match the related pattern for bitreverse.
     if (DestTy->isIntegerTy() &&
         DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
-        SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 &&
-        Shuf->hasOneUse() && Shuf->isReverse()) {
+        SrcTy->getScalarSizeInBits() == 8 &&
+        ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() &&
+        Shuf->isReverse()) {
       assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
       assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op");
       Function *Bswap =
@@ -2730,12 +2727,9 @@ Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
   Type *DestElemTy = DestTy->getElementType();
   if (SrcTy->getElementType() != DestElemTy) {
     Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace());
-    if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) {
-      // Handle vectors of pointers.
-      // FIXME: what should happen for scalable vectors?
-      MidTy = FixedVectorType::get(MidTy,
-                                   cast<FixedVectorType>(VT)->getNumElements());
-    }
+    // Handle vectors of pointers.
+    if (VectorType *VT = dyn_cast<VectorType>(CI.getType()))
+      MidTy = VectorType::get(MidTy, VT->getElementCount());
 
     Value *NewBitCast = Builder.CreateBitCast(Src, MidTy);
     return new AddrSpaceCastInst(NewBitCast, CI.getType());

diff --git a/llvm/test/Transforms/InstCombine/addrspacecast.ll b/llvm/test/Transforms/InstCombine/addrspacecast.ll
index 2e34f61a6623..343f2b8e7b4f 100644
--- a/llvm/test/Transforms/InstCombine/addrspacecast.ll
+++ b/llvm/test/Transforms/InstCombine/addrspacecast.ll
@@ -102,6 +102,16 @@ define <4 x float addrspace(2)*> @combine_addrspacecast_types_vector(<4 x i32 ad
   ret <4 x float addrspace(2)*> %y
 }
 
+define <vscale x 4 x float addrspace(2)*> @combine_addrspacecast_types_scalevector(<vscale x 4 x i32 addrspace(1)*> %x) nounwind {
+; CHECK-LABEL: @combine_addrspacecast_types_scalevector(
+; CHECK-NEXT: bitcast <vscale x 4 x i32 addrspace(1)*> %x to <vscale x 4 x float addrspace(1)*>
+; CHECK-NEXT: addrspacecast <vscale x 4 x float addrspace(1)*> %1 to <vscale x 4 x float addrspace(2)*>
+; CHECK-NEXT: ret
+  %y = addrspacecast <vscale x 4 x i32 addrspace(1)*> %x to <vscale x 4 x float addrspace(2)*>
+  ret <vscale x 4 x float addrspace(2)*> %y
+}
+
+
 define i32 @canonicalize_addrspacecast([16 x i32] addrspace(1)* %arr) {
 ; CHECK-LABEL: @canonicalize_addrspacecast(
 ; CHECK-NEXT: getelementptr [16 x i32], [16 x i32] addrspace(1)* %arr, i32 0, i32 0

diff --git a/llvm/test/Transforms/InstCombine/ptr-int-cast.ll b/llvm/test/Transforms/InstCombine/ptr-int-cast.ll
index 826c00484227..31040b8a377c 100644
--- a/llvm/test/Transforms/InstCombine/ptr-int-cast.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-int-cast.ll
@@ -1,60 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 target datalayout = "E-p:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128"
 
 define i1 @test1(i32 *%x) nounwind {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = ptrtoint i32* [[X:%.*]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = and i64 [[TMP0]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[TMP2]]
+;
 entry:
-; CHECK: test1
-; CHECK: ptrtoint i32* %x to i64
-	%0 = ptrtoint i32* %x to i1
-	ret i1 %0
+  %0 = ptrtoint i32* %x to i1
+  ret i1 %0
 }
 
 define i32* @test2(i128 %x) nounwind {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i128 [[X:%.*]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = inttoptr i64 [[TMP0]] to i32*
+; CHECK-NEXT:    ret i32* [[TMP1]]
+;
 entry:
-; CHECK: test2
-; CHECK: inttoptr i64 %0 to i32*
-	%0 = inttoptr i128 %x to i32*
-	ret i32* %0
+  %0 = inttoptr i128 %x to i32*
+  ret i32* %0
 }
 
 ; PR3574
-; CHECK: f0
-; CHECK: %1 = zext i32 %a0 to i64
-; CHECK: ret i64 %1
 define i64 @f0(i32 %a0) nounwind {
-       %t0 = inttoptr i32 %a0 to i8*
-       %t1 = ptrtoint i8* %t0 to i64
-       ret i64 %t1
+; CHECK-LABEL: @f0(
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[A0:%.*]] to i64
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+  %t0 = inttoptr i32 %a0 to i8*
+  %t1 = ptrtoint i8* %t0 to i64
+  ret i64 %t1
 }
 
 define <4 x i32> @test4(<4 x i8*> %arg) nounwind {
 ; CHECK-LABEL: @test4(
-; CHECK: ptrtoint <4 x i8*> %arg to <4 x i64>
-; CHECK: trunc <4 x i64> %1 to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint <4 x i8*> [[ARG:%.*]] to <4 x i64>
+; CHECK-NEXT:    [[P1:%.*]] = trunc <4 x i64> [[TMP1]] to <4 x i32>
+; CHECK-NEXT:    ret <4 x i32> [[P1]]
+;
   %p1 = ptrtoint <4 x i8*> %arg to <4 x i32>
   ret <4 x i32> %p1
 }
 
+define <vscale x 4 x i32> @testvscale4(<vscale x 4 x i8*> %arg) nounwind {
+; CHECK-LABEL: @testvscale4(
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint <vscale x 4 x i8*> [[ARG:%.*]] to <vscale x 4 x i64>
+; CHECK-NEXT:    [[P1:%.*]] = trunc <vscale x 4 x i64> [[TMP1]] to <vscale x 4 x i32>
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[P1]]
+;
+  %p1 = ptrtoint <vscale x 4 x i8*> %arg to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %p1
+}
+
 define <4 x i128> @test5(<4 x i8*> %arg) nounwind {
 ; CHECK-LABEL: @test5(
-; CHECK: ptrtoint <4 x i8*> %arg to <4 x i64>
-; CHECK: zext <4 x i64> %1 to <4 x i128>
+; CHECK-NEXT:    [[TMP1:%.*]] = ptrtoint <4 x i8*> [[ARG:%.*]] to <4 x i64>
+; CHECK-NEXT:    [[P1:%.*]] = zext <4 x i64> [[TMP1]] to <4 x i128>
+; CHECK-NEXT:    ret <4 x i128> [[P1]]
+;
   %p1 = ptrtoint <4 x i8*> %arg to <4 x i128>
   ret <4 x i128> %p1
 }
 
 define <4 x i8*> @test6(<4 x i32> %arg) nounwind {
 ; CHECK-LABEL: @test6(
-; CHECK: zext <4 x i32> %arg to <4 x i64>
-; CHECK: inttoptr <4 x i64> %1 to <4 x i8*>
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i32> [[ARG:%.*]] to <4 x i64>
+; CHECK-NEXT:    [[P1:%.*]] = inttoptr <4 x i64> [[TMP1]] to <4 x i8*>
+; CHECK-NEXT:    ret <4 x i8*> [[P1]]
+;
   %p1 = inttoptr <4 x i32> %arg to <4 x i8*>
   ret <4 x i8*> %p1
 }
 
 define <4 x i8*> @test7(<4 x i128> %arg) nounwind {
 ; CHECK-LABEL: @test7(
-; CHECK: trunc <4 x i128> %arg to <4 x i64>
-; CHECK: inttoptr <4 x i64> %1 to <4 x i8*>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <4 x i128> [[ARG:%.*]] to <4 x i64>
+; CHECK-NEXT:    [[P1:%.*]] = inttoptr <4 x i64> [[TMP1]] to <4 x i8*>
+; CHECK-NEXT:    ret <4 x i8*> [[P1]]
+;
   %p1 = inttoptr <4 x i128> %arg to <4 x i8*>
   ret <4 x i8*> %p1
 }

diff --git a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
index c44577d1f1d1..1ffea4224a4c 100644
--- a/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-extractelement.ll
@@ -18,6 +18,23 @@ define i32 @shrinkExtractElt_i64_to_i32_0(<3 x i64> %x) {
   ret i32 %t
 }
 
+define i32 @vscale_shrinkExtractElt_i64_to_i32_0(<vscale x 3 x i64> %x) {
+; LE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; LE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i32 0
+; LE-NEXT:    ret i32 [[T]]
+;
+; BE-LABEL: @vscale_shrinkExtractElt_i64_to_i32_0(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 3 x i64> [[X:%.*]] to <vscale x 6 x i32>
+; BE-NEXT:    [[T:%.*]] = extractelement <vscale x 6 x i32> [[TMP1]], i32 1
+; BE-NEXT:    ret i32 [[T]]
+;
+  %e = extractelement <vscale x 3 x i64> %x, i32 0
+  %t = trunc i64 %e to i32
+  ret i32 %t
+}
+
+
 define i32 @shrinkExtractElt_i64_to_i32_1(<3 x i64> %x) {
 ; LE-LABEL: @shrinkExtractElt_i64_to_i32_1(
 ; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x i64> [[X:%.*]] to <6 x i32>

diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
index 4f712ca7b3fb..874189d505c8 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
@@ -59,6 +59,20 @@ define float @test6(<4 x float> %X) {
   ret float %r
 }
 
+define float @testvscale6(<vscale x 4 x float> %X) {
+; CHECK-LABEL: @testvscale6(
+; CHECK-NEXT:    [[T:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> undef, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x float> [[T]], i32 0
+; CHECK-NEXT:    ret float [[R]]
+;
+  %X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>
+  %t = shufflevector <vscale x 4 x i32> %X1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = bitcast <vscale x 4 x i32> %t to <vscale x 4 x float>
+  %r = extractelement <vscale x 4 x float> %t2, i32 0
+  ret float %r
+}
+
+
 define <4 x float> @test7(<4 x float> %x) {
 ; CHECK-LABEL: @test7(
 ; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x float> [[X:%.*]], <4 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>


        


More information about the llvm-commits mailing list