[llvm] 999ac10 - [RISCVGatherScatterLowering] Support broadcast base pointer

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 7 07:42:09 PDT 2023


Author: Philip Reames
Date: 2023-08-07T07:42:04-07:00
New Revision: 999ac10d7649e41755a9624dbb508c2db8bf3ddd

URL: https://github.com/llvm/llvm-project/commit/999ac10d7649e41755a9624dbb508c2db8bf3ddd
DIFF: https://github.com/llvm/llvm-project/commit/999ac10d7649e41755a9624dbb508c2db8bf3ddd.diff

LOG: [RISCVGatherScatterLowering] Support broadcast base pointer

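A broadcast (splat) base pointer is equivalent to a scalar base pointer under GEP semantics, provided that at least one of the GEP's index operands is a vector (otherwise the result would be a scalar rather than a vector of pointers). SLP commonly emits this form, so the pass should handle it by extracting the splatted scalar and using it as the base.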

Differential Revision: https://reviews.llvm.org/D157132

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index b9c69a966b4ae8..fac3526c43148d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -331,8 +331,12 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
   SmallVector<Value *, 2> Ops(GEP->operands());
 
   // Base pointer needs to be a scalar.
-  if (Ops[0]->getType()->isVectorTy())
-    return std::make_pair(nullptr, nullptr);
+  Value *ScalarBase = Ops[0];
+  if (ScalarBase->getType()->isVectorTy()) {
+    ScalarBase = getSplatValue(ScalarBase);
+    if (!ScalarBase)
+      return std::make_pair(nullptr, nullptr);
+  }
 
   std::optional<unsigned> VecOperand;
   unsigned TypeScale = 0;
@@ -379,7 +383,7 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
     Ops[*VecOperand] = Start;
     Type *SourceTy = GEP->getSourceElementType();
     Value *BasePtr =
-        Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
+        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
 
     // Convert stride to pointer size if needed.
     Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
@@ -415,7 +419,7 @@ RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
   Ops[*VecOperand] = BasePhi;
   Type *SourceTy = GEP->getSourceElementType();
   Value *BasePtr =
-      Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
+      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());
 
   // Final adjustments to stride should go in the start block.
   Builder.SetInsertPoint(

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll
index ef8c2c19231178..7e1995fc0de4b8 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll
@@ -947,3 +947,19 @@ bb4:                                              ; preds = %bb4, %bb2
 bb16:                                             ; preds = %bb4, %bb
   ret void
 }
+
+define <8 x i8> @broadcast_ptr_base(ptr %a) {
+; CHECK-LABEL: @broadcast_ptr_base(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call <8 x i8> @llvm.riscv.masked.strided.load.v8i8.p0.i64(<8 x i8> poison, ptr [[A:%.*]], i64 64, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    ret <8 x i8> [[TMP0]]
+;
+entry:
+  %0 = insertelement <8 x ptr> poison, ptr %a, i64 0
+  %1 = shufflevector <8 x ptr> %0, <8 x ptr> poison, <8 x i32> zeroinitializer
+  %2 = getelementptr i8, <8 x ptr> %1, <8 x i64> <i64 0, i64 64, i64 128, i64 192, i64 256, i64 320, i64 384, i64 448>
+  %3 = tail call <8 x i8> @llvm.masked.gather.v8i8.v8p0(<8 x ptr> %2, i32 1, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i8> poison)
+  ret <8 x i8> %3
+}
+
+declare <8 x i8> @llvm.masked.gather.v8i8.v8p0(<8 x ptr>, i32 immarg, <8 x i1>, <8 x i8>)


        


More information about the llvm-commits mailing list