[llvm] [InstCombine] Handle more scalable geps in EmitGEPOffset (PR #71699)

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 10 04:30:58 PST 2023


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/71699

>From b6bab71853b12e1649dea76f53ac1b2caf230823 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Fri, 10 Nov 2023 12:29:25 +0000
Subject: [PATCH] [InstCombine] Handle more scalable geps in EmitGEPOffset

Following up on #71565, this makes the index splats in EmitGEPOffset use the
vector's ElementCount instead of assuming a fixed-width vector, and handles
scalable type sizes in vector GEPs by splatting the vscale-based scale to each
vector lane.

It appears that the `& PtrSizeMask` can be removed without altering any of the
existing tests, or any of the tests I tried on AArch64/Arm.
---
 llvm/lib/Analysis/Local.cpp                   | 22 ++++----
 .../Transforms/InstCombine/getelementptr.ll   | 53 +++++++++++++++++++
 2 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Analysis/Local.cpp b/llvm/lib/Analysis/Local.cpp
index ded6007663845e0..e81e78bc77bab0f 100644
--- a/llvm/lib/Analysis/Local.cpp
+++ b/llvm/lib/Analysis/Local.cpp
@@ -36,17 +36,10 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
       Result = Offset;
   };
 
-  // Build a mask for high order bits.
-  unsigned IntPtrWidth = IntIdxTy->getScalarType()->getIntegerBitWidth();
-  uint64_t PtrSizeMask =
-      std::numeric_limits<uint64_t>::max() >> (64 - IntPtrWidth);
-
   gep_type_iterator GTI = gep_type_begin(GEP);
   for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end(); i != e;
        ++i, ++GTI) {
     Value *Op = *i;
-    TypeSize TSize = DL.getTypeAllocSize(GTI.getIndexedType());
-    uint64_t Size = TSize.getKnownMinValue() & PtrSizeMask;
     if (Constant *OpC = dyn_cast<Constant>(Op)) {
       if (OpC->isZeroValue())
         continue;
@@ -54,7 +47,7 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
       // Handle a struct index, which adds its field offset to the pointer.
       if (StructType *STy = GTI.getStructTypeOrNull()) {
         uint64_t OpValue = OpC->getUniqueInteger().getZExtValue();
-        Size = DL.getStructLayout(STy)->getElementOffset(OpValue);
+        unsigned Size = DL.getStructLayout(STy)->getElementOffset(OpValue);
         if (!Size)
           continue;
 
@@ -66,16 +59,19 @@ Value *llvm::emitGEPOffset(IRBuilderBase *Builder, const DataLayout &DL,
     // Splat the index if needed.
     if (IntIdxTy->isVectorTy() && !Op->getType()->isVectorTy())
       Op = Builder->CreateVectorSplat(
-          cast<FixedVectorType>(IntIdxTy)->getNumElements(), Op);
+          cast<VectorType>(IntIdxTy)->getElementCount(), Op);
 
     // Convert to correct type.
     if (Op->getType() != IntIdxTy)
       Op = Builder->CreateIntCast(Op, IntIdxTy, true, Op->getName() + ".c");
-    if (Size != 1 || TSize.isScalable()) {
-      // We'll let instcombine(mul) convert this to a shl if possible.
-      auto *ScaleC = ConstantInt::get(IntIdxTy, Size);
+    TypeSize TSize = DL.getTypeAllocSize(GTI.getIndexedType());
+    if (TSize != TypeSize::Fixed(1)) {
       Value *Scale =
-          !TSize.isScalable() ? ScaleC : Builder->CreateVScale(ScaleC);
+          Builder->CreateTypeSize(IntIdxTy->getScalarType(), TSize);
+      if (IntIdxTy->isVectorTy())
+        Scale = Builder->CreateVectorSplat(
+            cast<VectorType>(IntIdxTy)->getElementCount(), Scale);
+      // We'll let instcombine(mul) convert this to a shl if possible.
       Op = Builder->CreateMul(Op, Scale, GEP->getName() + ".idx", false /*NUW*/,
                               isInBounds /*NSW*/);
     }
diff --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index 752dd6f6877dd58..bc7fdc9352df6cd 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -233,6 +233,59 @@ define <2 x i1> @test13_vector2(i64 %X, <2 x ptr> %P) nounwind {
   ret <2 x i1> %C
 }
 
+define <2 x i1> @test13_fixed_fixed(i64 %X, ptr %P, <2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_fixed_fixed(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 3, i64 0>
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = shl nsw <2 x i64> [[Y:%.*]], <i64 4, i64 4>
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <2 x i64>, ptr %P, <2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <2 x i64>, ptr %P, <2 x i64> %y
+  %C = icmp eq <2 x ptr> %A, %B
+  ret <2 x i1> %C
+}
+
+define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_fixed_scalable(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i64> [[DOTSPLATINSERT]], <i64 3, i64 0>
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i64 [[TMP2]], 4
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <2 x i64> poison, i64 [[TMP3]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <2 x i64> [[DOTSPLATINSERT1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = mul nsw <2 x i64> [[DOTSPLAT2]], [[Y:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <2 x i64> %y
+  %C = icmp eq <2 x ptr> %A, %B
+  ret <2 x i1> %C
+}
+
+define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
+; CHECK-LABEL: @test13_scalable_scalable(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> poison, i64 3, i64 0), <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT1]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[B_IDX:%.*]] = mul nsw <vscale x 2 x i64> [[DOTSPLAT2]], [[Y:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = icmp eq <vscale x 2 x i64> [[A_IDX]], [[B_IDX]]
+; CHECK-NEXT:    ret <vscale x 2 x i1> [[C]]
+;
+  %A = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <vscale x 2 x i64> zeroinitializer, i64 %X
+  %B = getelementptr inbounds <vscale x 2 x i64>, ptr %P, <vscale x 2 x i64> %y
+  %C = icmp eq <vscale x 2 x ptr> %A, %B
+  ret <vscale x 2 x i1> %C
+}
+
 ; This is a test of icmp + shl nuw in disguise - 4611... is 0x3fff...
 define <2 x i1> @test13_vector3(i64 %X, <2 x ptr> %P) nounwind {
 ; CHECK-LABEL: @test13_vector3(



More information about the llvm-commits mailing list