[llvm-branch-commits] [llvm] e12f584 - [InstCombine] Remove scalable vector restriction in InstCombineCompares

Jun Ma via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 15 04:52:15 PST 2020


Author: Jun Ma
Date: 2020-12-15T20:36:57+08:00
New Revision: e12f584578006e877cc947cde17c8da86177e9cc

URL: https://github.com/llvm/llvm-project/commit/e12f584578006e877cc947cde17c8da86177e9cc
DIFF: https://github.com/llvm/llvm-project/commit/e12f584578006e877cc947cde17c8da86177e9cc.diff

LOG: [InstCombine] Remove scalable vector restriction in InstCombineCompares

Differential Revision: https://reviews.llvm.org/D93269
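
At a glance, the change swaps the FixedVectorType-only calls (FixedVectorType::get plus
cast<FixedVectorType>(...)->getNumElements()) for their ElementCount-based equivalents
(VectorType::get and getElementCount()), so the affected icmp folds also apply to scalable
vectors. The C++ sketch below is not part of the commit; it only illustrates the generalized
API shape, and the helper name retypeAndSplat is invented for the example.

  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/IRBuilder.h"

  using namespace llvm;

  // Build a same-shaped vector type with a new element type and splat a
  // scalar across all lanes.  ElementCount carries both the (minimum) lane
  // count and the scalable flag, so one code path covers <4 x i32> and
  // <vscale x 4 x i32>; the FixedVectorType-based code could only take the
  // former.
  static Value *retypeAndSplat(IRBuilder<> &B, VectorType *VTy,
                               Type *NewEltTy, Value *Scalar) {
    ElementCount EC = VTy->getElementCount();
    Type *NewVTy = VectorType::get(NewEltTy, EC); // e.g. a trunc destination
    (void)NewVTy;
    return B.CreateVectorSplat(EC, Scalar);
  }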

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/vscale_cmp.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index d6285dcd387d..139b04bb6a81 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -899,8 +899,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     // For vectors, we apply the same reasoning on a per-lane basis.
     auto *Base = GEPLHS->getPointerOperand();
     if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
-      int NumElts = cast<FixedVectorType>(GEPLHS->getType())->getNumElements();
-      Base = Builder.CreateVectorSplat(NumElts, Base);
+      auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount();
+      Base = Builder.CreateVectorSplat(EC, Base);
     }
     return new ICmpInst(Cond, Base,
                         ConstantExpr::getPointerBitCastOrAddrSpaceCast(
@@ -1885,8 +1885,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
     if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
       Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
       if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
-        NTy = FixedVectorType::get(
-            NTy, cast<FixedVectorType>(AndVTy)->getNumElements());
+        NTy = VectorType::get(NTy, AndVTy->getElementCount());
       Value *Trunc = Builder.CreateTrunc(X, NTy);
       auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
                                                             : CmpInst::ICMP_SLT;
@@ -2192,8 +2191,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
       DL.isLegalInteger(TypeBits - Amt)) {
     Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
     if (auto *ShVTy = dyn_cast<VectorType>(ShType))
-      TruncTy = FixedVectorType::get(
-          TruncTy, cast<FixedVectorType>(ShVTy)->getNumElements());
+      TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount());
     Constant *NewC =
         ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
     return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
@@ -2827,8 +2825,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
 
           Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
           if (auto *XVTy = dyn_cast<VectorType>(XType))
-            NewType = FixedVectorType::get(
-                NewType, cast<FixedVectorType>(XVTy)->getNumElements());
+            NewType = VectorType::get(NewType, XVTy->getElementCount());
           Value *NewBitcast = Builder.CreateBitCast(X, NewType);
           if (TrueIfSigned)
             return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
@@ -3411,8 +3408,8 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
   // those elements by copying an existing, defined, and safe scalar constant.
   Type *OpTy = M->getType();
   auto *VecC = dyn_cast<Constant>(M);
-  if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
-    auto *OpVTy = cast<FixedVectorType>(OpTy);
+  auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
+  if (OpVTy && VecC && VecC->containsUndefElement()) {
     Constant *SafeReplacementConstant = nullptr;
     for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
       if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
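
The last hunk above is the one spot that stays fixed-width only: it switches
cast<FixedVectorType> to dyn_cast<FixedVectorType> so scalable vectors simply bypass the
undef-element scrubbing, because the loop that follows rewrites the constant element by
element and needs a lane count that is known at compile time. A minimal sketch of that
distinction (the helper name canRewritePerLane is invented here):

  #include "llvm/IR/DerivedTypes.h"

  using namespace llvm;

  // Per-lane rewrites, such as replacing undef elements of a vector constant,
  // need the number of lanes at compile time.  A scalable vector has
  // getKnownMinValue() * vscale lanes, with vscale only known at run time, so
  // those rewrites stay restricted to fixed-width vectors.
  static bool canRewritePerLane(VectorType *VTy) {
    return !VTy->getElementCount().isScalable();
  }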

diff --git a/llvm/test/Transforms/InstCombine/vscale_cmp.ll b/llvm/test/Transforms/InstCombine/vscale_cmp.ll
index bbceab06e3fc..e7b8a2e3e3f2 100644
--- a/llvm/test/Transforms/InstCombine/vscale_cmp.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_cmp.ll
@@ -9,3 +9,27 @@ define <vscale x 2 x i1> @sge(<vscale x 2 x i8> %x) {
   %cmp = icmp sge <vscale x 2 x i8> %x, zeroinitializer
   ret <vscale x 2 x i1> %cmp
 }
+
+define <vscale x 2 x i1> @gep_scalevector1(i32* %X) nounwind {
+; CHECK-LABEL: @gep_scalevector1(
+; CHECK-NEXT:    [[S:%.*]] = insertelement <vscale x 2 x i32*> undef, i32* [[X:%.*]], i32 0
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <vscale x 2 x i32*> [[S]], zeroinitializer
+; CHECK-NEXT:    [[C1:%.*]] = shufflevector <vscale x 2 x i1> [[C]], <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 2 x i1> [[C1]]
+;
+  %A = getelementptr inbounds i32, i32* %X, <vscale x 2 x i64> zeroinitializer
+  %C = icmp eq <vscale x 2 x i32*> %A, zeroinitializer
+  ret <vscale x 2 x i1> %C
+}
+
+define <vscale x 2 x i1> @signbit_bitcast_fpext_scalevec(<vscale x 2 x half> %x) {
+; CHECK-LABEL: @signbit_bitcast_fpext_scalevec(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <vscale x 2 x half> [[X:%.*]] to <vscale x 2 x i16>
+; CHECK-NEXT:    [[R:%.*]] = icmp slt <vscale x 2 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    ret <vscale x 2 x i1> [[R]]
+;
+  %f = fpext <vscale x 2 x half> %x to <vscale x 2 x float>
+  %b = bitcast <vscale x 2 x float> %f to <vscale x 2 x i32>
+  %r = icmp slt <vscale x 2 x i32> %b, zeroinitializer
+  ret <vscale x 2 x i1> %r
+}

More information about the llvm-branch-commits mailing list