[llvm] 8819202 - [SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold

Christopher Tetreault via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 17 14:20:13 PDT 2020


Author: Christopher Tetreault
Date: 2020-06-17T14:19:56-07:00
New Revision: 8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a

URL: https://github.com/llvm/llvm-project/commit/8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a
DIFF: https://github.com/llvm/llvm-project/commit/8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a.diff

LOG: [SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold

Summary:
Assume all uses of VectorType::getNumElements() in this file are explicitly
fixed-width operations and cast to FixedVectorType

Reviewers: efriedma, sdesmalen, c-rhodes, majnemer, dblaikie

Reviewed By: sdesmalen

Subscribers: tschuett, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80262

Added: 
    llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll

Modified: 
    llvm/lib/IR/ConstantFold.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3fb49e94870f..ef584afc68bc 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -55,8 +55,8 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
   // If this cast changes element count then we can't handle it here:
   // doing so requires endianness information.  This should be handled by
   // Analysis/ConstantFolding.cpp
-  unsigned NumElts = DstTy->getNumElements();
-  if (NumElts != cast<VectorType>(CV->getType())->getNumElements())
+  unsigned NumElts = cast<FixedVectorType>(DstTy)->getNumElements();
+  if (NumElts != cast<FixedVectorType>(CV->getType())->getNumElements())
     return nullptr;
 
   Type *DstEltTy = DstTy->getElementType();
@@ -573,8 +573,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   // count may be mismatched; don't attempt to handle that here.
   if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
       DestTy->isVectorTy() &&
-      cast<VectorType>(DestTy)->getNumElements() ==
-          cast<VectorType>(V->getType())->getNumElements()) {
+      cast<FixedVectorType>(DestTy)->getNumElements() ==
+          cast<FixedVectorType>(V->getType())->getNumElements()) {
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
     // Fast path for splatted constants.
@@ -585,7 +585,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     }
     SmallVector<Constant *, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
-    for (unsigned i = 0, e = cast<VectorType>(V->getType())->getNumElements();
+    for (unsigned i = 0,
+                  e = cast<FixedVectorType>(V->getType())->getNumElements();
          i != e; ++i) {
       Constant *C =
         ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i));
@@ -809,9 +810,11 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
   if (!CIdx)
     return nullptr;
 
-  // ee({w,x,y,z}, wrong_value) -> undef
-  if (CIdx->uge(ValVTy->getNumElements()))
-    return UndefValue::get(ValVTy->getElementType());
+  if (auto *ValFVTy = dyn_cast<FixedVectorType>(Val->getType())) {
+    // ee({w,x,y,z}, wrong_value) -> undef
+    if (CIdx->uge(ValFVTy->getNumElements()))
+      return UndefValue::get(ValFVTy->getElementType());
+  }
 
   // ee (gep (ptr, idx0, ...), idx) -> gep (ee (ptr, idx), ee (idx0, idx), ...)
   if (auto *CE = dyn_cast<ConstantExpr>(Val)) {
@@ -823,7 +826,7 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
         if (Op->getType()->isVectorTy()) {
           Constant *ScalarOp = ConstantExpr::getExtractElement(Op, Idx);
           if (!ScalarOp)
-            return  nullptr;
+            return nullptr;
           Ops.push_back(ScalarOp);
         } else
           Ops.push_back(Op);
@@ -833,6 +836,16 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
     }
   }
 
+  // CAZ of type ScalableVectorType and n < CAZ->getMinNumElements() =>
+  //   extractelt CAZ, n -> 0
+  if (auto *ValSVTy = dyn_cast<ScalableVectorType>(Val->getType())) {
+    if (!CIdx->uge(ValSVTy->getMinNumElements())) {
+      if (auto *CAZ = dyn_cast<ConstantAggregateZero>(Val))
+        return CAZ->getElementValue(CIdx->getZExtValue());
+    }
+    return nullptr;
+  }
+
   return Val->getAggregateElement(CIdx);
 }
 
@@ -847,11 +860,12 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
 
   // Do not iterate on scalable vector. The num of elements is unknown at
   // compile-time.
-  VectorType *ValTy = cast<VectorType>(Val->getType());
-  if (isa<ScalableVectorType>(ValTy))
+  if (isa<ScalableVectorType>(Val->getType()))
     return nullptr;
 
-  unsigned NumElts = cast<VectorType>(Val->getType())->getNumElements();
+  auto *ValTy = cast<FixedVectorType>(Val->getType());
+
+  unsigned NumElts = ValTy->getNumElements();
   if (CIdx->uge(NumElts))
     return UndefValue::get(Val->getType());
 
@@ -898,7 +912,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2,
   if (isa<ScalableVectorType>(V1VTy))
     return nullptr;
 
-  unsigned SrcNumElts = V1VTy->getNumElements();
+  unsigned SrcNumElts = V1VTy->getElementCount().Min;
 
   // Loop over the shuffle mask, evaluating each element.
   SmallVector<Constant*, 32> Result;
@@ -998,11 +1012,8 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
     case Instruction::FNeg:
       return ConstantFP::get(C->getContext(), neg(CV));
     }
-  } else if (VectorType *VTy = dyn_cast<VectorType>(C->getType())) {
-    // Do not iterate on scalable vector. The number of elements is unknown at
-    // compile-time.
-    if (IsScalableVector)
-      return nullptr;
+  } else if (auto *VTy = dyn_cast<FixedVectorType>(C->getType())) {
+
     Type *Ty = IntegerType::get(VTy->getContext(), 32);
     // Fast path for splatted constants.
     if (Constant *Splat = C->getSplatValue()) {
@@ -1011,7 +1022,7 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
     }
 
     // Fold each element and create a vector constant from those constants.
-    SmallVector<Constant*, 16> Result;
+    SmallVector<Constant *, 16> Result;
     for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
       Constant *ExtractIdx = ConstantInt::get(Ty, i);
       Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1367,11 +1378,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
         return ConstantFP::get(C1->getContext(), C3V);
       }
     }
-  } else if (VectorType *VTy = dyn_cast<VectorType>(C1->getType())) {
+  } else if (IsScalableVector) {
     // Do not iterate on scalable vector. The number of elements is unknown at
     // compile-time.
-    if (IsScalableVector)
-      return nullptr;
+    // FIXME: this branch can potentially be removed
+    return nullptr;
+  } else if (auto *VTy = dyn_cast<FixedVectorType>(C1->getType())) {
     // Fast path for splatted constants.
     if (Constant *C2Splat = C2->getSplatValue()) {
       if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
@@ -2014,7 +2026,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     SmallVector<Constant*, 4> ResElts;
     Type *Ty = IntegerType::get(C1->getContext(), 32);
     // Compare the elements, producing an i1 result or constant expr.
-    for (unsigned i = 0, e = C1VTy->getNumElements(); i != e; ++i) {
+    for (unsigned i = 0, e = C1VTy->getElementCount().Min; i != e; ++i) {
       Constant *C1E =
         ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i));
       Constant *C2E =
@@ -2286,14 +2298,18 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
       assert(Ty && "Invalid indices for GEP!");
       Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
       Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
-      if (VectorType *VT = dyn_cast<VectorType>(C->getType()))
-        GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
-
+      if (VectorType *VT = dyn_cast<VectorType>(C->getType())) {
+        // FIXME: handle scalable vectors (use getElementCount())
+        GEPTy = FixedVectorType::get(
+            OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
+      }
       // The GEP returns a vector of pointers when one of more of
       // its arguments is a vector.
       for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
         if (auto *VT = dyn_cast<VectorType>(Idxs[i]->getType())) {
-          GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
+          // FIXME: handle scalable vectors
+          GEPTy = FixedVectorType::get(
+              OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
           break;
         }
       }
@@ -2500,19 +2516,19 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
 
     if (!IsCurrIdxVector && IsPrevIdxVector)
       CurrIdx = ConstantDataVector::getSplat(
-          cast<VectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
+          cast<FixedVectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
 
     if (!IsPrevIdxVector && IsCurrIdxVector)
       PrevIdx = ConstantDataVector::getSplat(
-          cast<VectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
+          cast<FixedVectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
 
     Constant *Factor =
         ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
     if (UseVector)
       Factor = ConstantDataVector::getSplat(
           IsPrevIdxVector
-              ? cast<VectorType>(PrevIdx->getType())->getNumElements()
-              : cast<VectorType>(CurrIdx->getType())->getNumElements(),
+              ? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
+              : cast<FixedVectorType>(CurrIdx->getType())->getNumElements(),
           Factor);
 
     NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
@@ -2531,8 +2547,8 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
       ExtendedTy = FixedVectorType::get(
           ExtendedTy,
           IsPrevIdxVector
-              ? cast<VectorType>(PrevIdx->getType())->getNumElements()
-              : cast<VectorType>(CurrIdx->getType())->getNumElements());
+              ? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
+              : cast<FixedVectorType>(CurrIdx->getType())->getNumElements());
 
     if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
       PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);

diff  --git a/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll
new file mode 100644
index 000000000000..c4b42be45019
--- /dev/null
+++ b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll
@@ -0,0 +1,13 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+; CHECK-LABEL: definitely_in_bounds
+; CHECK: ret i8 0
+define i8 @definitely_in_bounds() {
+  ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 15)
+}
+
+; CHECK-LABEL: maybe_in_bounds
+; CHECK: ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
+define i8 @maybe_in_bounds() {
+  ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
+}


        


More information about the llvm-commits mailing list