[llvm] 73af455 - [InstCombine] Handle more scalable geps in EmitGEPOffset (#71699)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 11 10:21:35 PST 2023


Author: David Green
Date: 2023-11-11T18:21:31Z
New Revision: 73af455f57dad2ae69c73ca27435be83ddedc6f3

URL: https://github.com/llvm/llvm-project/commit/73af455f57dad2ae69c73ca27435be83ddedc6f3
DIFF: https://github.com/llvm/llvm-project/commit/73af455f57dad2ae69c73ca27435be83ddedc6f3.diff

LOG: [InstCombine] Handle more scalable geps in EmitGEPOffset (#71699)

Following up on #71565, this makes the index splats in EmitGEPOffset use
the vector type's ElementCount instead of assuming a fixed-width vector,
so scalable vector index types are supported. It also handles scalable
type sizes in vector GEPs by splatting the vscale-based scale to each
vector lane.

Added: 
    

Modified: 
    llvm/lib/Analysis/Local.cpp
    llvm/test/Transforms/InstCombine/getelementptr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/Local.cpp b/llvm/lib/Analysis/Local.cpp
index ded6007663845e0..8dd1a7e2975347d 100644
--- a/llvm/lib/Analysis/Local.cpp
+++ b/llvm/lib/Analysis/Local.cpp
@@ -36,17 +36,10 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
       Result = Offset;
   };
 
-  // Build a mask for high order bits.
-  unsigned IntPtrWidth = IntIdxTy->getScalarType()->getIntegerBitWidth();
-  uint64_t PtrSizeMask =
-      std::numeric_limits<uint64_t>::max() >> (64 - IntPtrWidth);
-
   gep_type_iterator GTI = gep_type_begin(GEP);
   for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end(); i != e;
        ++i, ++GTI) {
     Value *Op = *i;
-    TypeSize TSize = DL.getTypeAllocSize(GTI.getIndexedType());
-    uint64_t Size = TSize.getKnownMinValue() & PtrSizeMask;
     if (Constant *OpC = dyn_cast<Constant>(Op)) {
       if (OpC->isZeroValue())
         continue;
@@ -54,7 +47,7 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
       // Handle a struct index, which adds its field offset to the pointer.
       if (StructType *STy = GTI.getStructTypeOrNull()) {
         uint64_t OpValue = OpC->getUniqueInteger().getZExtValue();
-        Size = DL.getStructLayout(STy)->getElementOffset(OpValue);
+        uint64_t Size = DL.getStructLayout(STy)->getElementOffset(OpValue);
         if (!Size)
           continue;
 
@@ -66,16 +59,18 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
     // Splat the index if needed.
     if (IntIdxTy->isVectorTy() && !Op->getType()->isVectorTy())
       Op = Builder->CreateVectorSplat(
-          cast<FixedVectorType>(IntIdxTy)->getNumElements(), Op);
+          cast<VectorType>(IntIdxTy)->getElementCount(), Op);
 
     // Convert to correct type.
     if (Op->getType() != IntIdxTy)
       Op = Builder->CreateIntCast(Op, IntIdxTy, true, Op->getName() + ".c");
-    if (Size != 1 || TSize.isScalable()) {
+    TypeSize TSize = DL.getTypeAllocSize(GTI.getIndexedType());
+    if (TSize != TypeSize::Fixed(1)) {
+      Value *Scale = Builder->CreateTypeSize(IntIdxTy->getScalarType(), TSize);
+      if (IntIdxTy->isVectorTy())
+        Scale = Builder->CreateVectorSplat(
+            cast<VectorType>(IntIdxTy)->getElementCount(), Scale);
       // We'll let instcombine(mul) convert this to a shl if possible.
-      auto *ScaleC = ConstantInt::get(IntIdxTy, Size);
-      Value *Scale =
-          !TSize.isScalable() ? ScaleC : Builder->CreateVScale(ScaleC);
       Op = Builder->CreateMul(Op, Scale, GEP->getName() + ".idx", false /*NUW*/,
                               isInBounds /*NSW*/);
     }

diff  --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index 752dd6f6877dd58..bc7fdc9352df6cd 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -233,6 +233,59 @@ define <2 x i1> @test13_vector2(i64 %X, <2 x ptr> %P) nounwind {
   ret <2 x i1> %C
 }
 
+define <2 x i1> @test13_fixed_fixed(i64 %X, ptr %P, <2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_fixed_fixed(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 3, i64 0>
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = shl nsw <2 x i64> [[Y:%.*]], <i64 4, i64 4>
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <2 x i64>, ptr %P, <2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <2 x i64>, ptr %P, <2 x i64> %y
+  %C = icmp eq <2 x ptr> %A, %B
+  ret <2 x i1> %C
+}
+
+define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_fixed_scalable(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 3, i64 0>
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i64 [[TMP2]], 4
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <2 x i64> poison, i64 [[TMP3]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <2 x i64> [[DOTSPLATINSERT1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = mul nsw <2 x i64> [[DOTSPLAT2]], [[Y:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <2 x i64> %y
+  %C = icmp eq <2 x ptr> %A, %B
+  ret <2 x i1> %C
+}
+
+define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_scalable_scalable(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 3, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT1]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = mul nsw <vscale x 2 x i64> [[DOTSPLAT2]], [[Y:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <vscale x 2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <vscale x 2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <vscale x 2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <vscale x 2 x i64> %y
+  %C = icmp eq <vscale x 2 x ptr> %A, %B
+  ret <vscale x 2 x i1> %C
+}
+
 ; This is a test of icmp + shl nuw in disguise - 4611... is 0x3fff...
 define <2 x i1> @test13_vector3(i64 %X, <2 x ptr> %P) nounwind {
 ; CHECK-LABEL: @test13_vector3(


        


More information about the llvm-commits mailing list