[llvm] r319683 - [ConstantFold] Support vector index when factoring out GEP index into preceding dimensions

Haicheng Wu via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 4 11:56:33 PST 2017


Author: haicheng
Date: Mon Dec  4 11:56:33 2017
New Revision: 319683

URL: http://llvm.org/viewvc/llvm-project?rev=319683&view=rev
Log:
[ConstantFold] Support vector index when factoring out GEP index into preceding dimensions

Follow-up to r316824. This patch adds support for vector-typed indices, for both
the current index and the previous one, when factoring the current index into
the preceding dimension.

Differential Revision: https://reviews.llvm.org/D39556

Modified:
    llvm/trunk/lib/IR/ConstantFold.cpp
    llvm/trunk/test/Assembler/getelementptr_vec_ce.ll
    llvm/trunk/test/Transforms/InstCombine/gep-vector.ll
    llvm/trunk/test/Transforms/InstSimplify/vector_gep.ll

Modified: llvm/trunk/lib/IR/ConstantFold.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/IR/ConstantFold.cpp?rev=319683&r1=319682&r2=319683&view=diff
==============================================================================
--- llvm/trunk/lib/IR/ConstantFold.cpp (original)
+++ llvm/trunk/lib/IR/ConstantFold.cpp Mon Dec  4 11:56:33 2017
@@ -2210,17 +2210,17 @@ Constant *llvm::ConstantFoldGetElementPt
   SmallVector<Constant *, 8> NewIdxs;
   Type *Ty = PointeeTy;
   Type *Prev = C->getType();
-  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  bool Unknown =
+      !isa<ConstantInt>(Idxs[0]) && !isa<ConstantDataVector>(Idxs[0]);
   for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
-    auto *CI = dyn_cast<ConstantInt>(Idxs[i]);
-    if (!CI) {
+    if (!isa<ConstantInt>(Idxs[i]) && !isa<ConstantDataVector>(Idxs[i])) {
       // We don't know if it's in range or not.
       Unknown = true;
       continue;
     }
-    if (!isa<ConstantInt>(Idxs[i - 1]))
-      // FIXME: add the support of cosntant vector index.
+    if (!isa<ConstantInt>(Idxs[i - 1]) && !isa<ConstantDataVector>(Idxs[i - 1]))
+      // Skip if the type of the previous index is not supported.
       continue;
     if (InRangeIndex && i == *InRangeIndex + 1) {
       // If an index is marked inrange, we cannot apply this canonicalization to
@@ -2238,46 +2238,91 @@ Constant *llvm::ConstantFoldGetElementPt
       Unknown = true;
       continue;
     }
-    if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
-      // It's in range, skip to the next index.
-      continue;
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
+      if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
+        // It's in range, skip to the next index.
+        continue;
+      if (CI->getSExtValue() < 0) {
+        // It's out of range and negative, don't try to factor it.
+        Unknown = true;
+        continue;
+      }
+    } else {
+      auto *CV = cast<ConstantDataVector>(Idxs[i]);
+      bool InRange = true;
+      for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) {
+        auto *CI = cast<ConstantInt>(CV->getElementAsConstant(I));
+        InRange &= isIndexInRangeOfArrayType(STy->getNumElements(), CI);
+        if (CI->getSExtValue() < 0) {
+          Unknown = true;
+          break;
+        }
+      }
+      if (InRange || Unknown)
+        // It's in range, skip to the next index.
+        // It's out of range and negative, don't try to factor it.
+        continue;
+    }
     if (isa<StructType>(Prev)) {
       // It's out of range, but the prior dimension is a struct
       // so we can't do anything about it.
       Unknown = true;
       continue;
     }
-    if (CI->getSExtValue() < 0) {
-      // It's out of range and negative, don't try to factor it.
-      Unknown = true;
-      continue;
-    }
     // It's out of range, but we can factor it into the prior
     // dimension.
     NewIdxs.resize(Idxs.size());
     // Determine the number of elements in our sequential type.
     uint64_t NumElements = STy->getArrayNumElements();
 
-    ConstantInt *Factor = ConstantInt::get(CI->getType(), NumElements);
-    NewIdxs[i] = ConstantExpr::getSRem(CI, Factor);
+    // Expand the current index or the previous index to a vector from a scalar
+    // if necessary.
+    Constant *CurrIdx = cast<Constant>(Idxs[i]);
+    auto *PrevIdx =
+        NewIdxs[i - 1] ? NewIdxs[i - 1] : cast<Constant>(Idxs[i - 1]);
+    bool IsCurrIdxVector = CurrIdx->getType()->isVectorTy();
+    bool IsPrevIdxVector = PrevIdx->getType()->isVectorTy();
+    bool UseVector = IsCurrIdxVector || IsPrevIdxVector;
+
+    if (!IsCurrIdxVector && IsPrevIdxVector)
+      CurrIdx = ConstantDataVector::getSplat(
+          PrevIdx->getType()->getVectorNumElements(), CurrIdx);
 
-    Constant *PrevIdx = NewIdxs[i-1] ? NewIdxs[i-1] :
-                           cast<Constant>(Idxs[i - 1]);
-    Constant *Div = ConstantExpr::getSDiv(CI, Factor);
+    if (!IsPrevIdxVector && IsCurrIdxVector)
+      PrevIdx = ConstantDataVector::getSplat(
+          CurrIdx->getType()->getVectorNumElements(), PrevIdx);
+
+    Constant *Factor =
+        ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
+    if (UseVector)
+      Factor = ConstantDataVector::getSplat(
+          IsPrevIdxVector ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements(),
+          Factor);
+
+    NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
+
+    Constant *Div = ConstantExpr::getSDiv(CurrIdx, Factor);
 
     unsigned CommonExtendedWidth =
-        std::max(PrevIdx->getType()->getIntegerBitWidth(),
-                 Div->getType()->getIntegerBitWidth());
+        std::max(PrevIdx->getType()->getScalarSizeInBits(),
+                 Div->getType()->getScalarSizeInBits());
     CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
 
     // Before adding, extend both operands to i64 to avoid
     // overflow trouble.
-    if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
-      PrevIdx = ConstantExpr::getSExt(
-          PrevIdx, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
-    if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
-      Div = ConstantExpr::getSExt(
-          Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+    Type *ExtendedTy = Type::getIntNTy(Div->getContext(), CommonExtendedWidth);
+    if (UseVector)
+      ExtendedTy = VectorType::get(
+          ExtendedTy, IsPrevIdxVector
+                          ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements());
+
+    if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);
+
+    if (!Div->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      Div = ConstantExpr::getSExt(Div, ExtendedTy);
 
     NewIdxs[i - 1] = ConstantExpr::getAdd(PrevIdx, Div);
   }

Modified: llvm/trunk/test/Assembler/getelementptr_vec_ce.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Assembler/getelementptr_vec_ce.ll?rev=319683&r1=319682&r2=319683&view=diff
==============================================================================
--- llvm/trunk/test/Assembler/getelementptr_vec_ce.ll (original)
+++ llvm/trunk/test/Assembler/getelementptr_vec_ce.ll Mon Dec  4 11:56:33 2017
@@ -3,7 +3,7 @@
 @G = global [4 x i32] zeroinitializer
 
 ; CHECK-LABEL: @foo
-; CHECK: ret <4 x i32*> getelementptr ([4 x i32], [4 x i32]* @G, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
+; CHECK: ret <4 x i32*> getelementptr inbounds ([4 x i32], [4 x i32]* @G, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
 define <4 x i32*> @foo() {
   ret <4 x i32*> getelementptr ([4 x i32], [4 x i32]* @G, i32 0, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
 }

Modified: llvm/trunk/test/Transforms/InstCombine/gep-vector.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/gep-vector.ll?rev=319683&r1=319682&r2=319683&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/gep-vector.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/gep-vector.ll Mon Dec  4 11:56:33 2017
@@ -16,9 +16,23 @@ define <8 x i64*> @patatino2() {
 
 @block = global [64 x [8192 x i8]] zeroinitializer, align 1
 
-; CHECK-LABEL:vectorindex
-; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 0, i64 1>, <2 x i64> <i64 8192, i64 8192>)
-define <2 x i8*> @vectorindex() {
+; CHECK-LABEL:vectorindex1
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 1, i64 2>, <2 x i64> zeroinitializer)
+define <2 x i8*> @vectorindex1() {
   %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, <2 x i64> <i64 0, i64 1>, i64 8192
   ret <2 x i8*> %1
 }
+
+; CHECK-LABEL:vectorindex2
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 1, i64 2>, <2 x i64> <i64 8191, i64 1>)
+define <2 x i8*> @vectorindex2() {
+  %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, i64 1, <2 x i64> <i64 8191, i64 8193>
+  ret <2 x i8*> %1
+}
+
+; CHECK-LABEL:vectorindex3
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 0, i64 2>, <2 x i64> <i64 8191, i64 1>)
+define <2 x i8*> @vectorindex3() {
+  %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, <2 x i64> <i64 0, i64 1>, <2 x i64> <i64 8191, i64 8193>
+  ret <2 x i8*> %1
+}

Modified: llvm/trunk/test/Transforms/InstSimplify/vector_gep.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstSimplify/vector_gep.ll?rev=319683&r1=319682&r2=319683&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstSimplify/vector_gep.ll (original)
+++ llvm/trunk/test/Transforms/InstSimplify/vector_gep.ll Mon Dec  4 11:56:33 2017
@@ -58,7 +58,7 @@ define <4 x i8*> @test5() {
 
 define <16 x i32*> @test6() {
 ; CHECK-LABEL: @test6
-; CHECK-NEXT: ret <16 x i32*> getelementptr ([24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, <16 x i64> zeroinitializer, <16 x i64> zeroinitializer, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, <16 x i64> zeroinitializer)
+; CHECK-NEXT: ret <16 x i32*> getelementptr inbounds ([24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, <16 x i64> zeroinitializer, <16 x i64> zeroinitializer, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, <16 x i64> zeroinitializer)
   %VectorGep = getelementptr [24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, i64 0, i64 0, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, i64 0
   ret <16 x i32*> %VectorGep
 }




More information about the llvm-commits mailing list