[llvm] [RISCV] Match gather(splat(ptr)) as zero strided load (PR #65769)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 8 08:15:44 PDT 2023


https://github.com/preames created https://github.com/llvm/llvm-project/pull/65769:

We already handled the case where the broadcast was done via a GEP, but not the case where it was done via a shuffle.
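
For illustration, a minimal IR input exercising the new path might look like the following (a hypothetical sketch, not taken from the patch's test file; the function name is made up). The pointer operand of the gather is a splat of %base built with insertelement + shufflevector, which determineBaseAndStride can now recognize as base %base with stride 0; on RV32 this then lowers to the `vlse32.v v8, (a0), zero` seen in the updated test checks below.

  ; Hypothetical example: gather of a shuffle-splat pointer vector.
  define <4 x i32> @gather_of_splat_ptr(ptr %base) {
    ; Broadcast %base into all four lanes via insertelement + shufflevector.
    %head = insertelement <4 x ptr> poison, ptr %base, i32 0
    %splat = shufflevector <4 x ptr> %head, <4 x ptr> poison, <4 x i32> zeroinitializer
    ; All-true mask: every lane loads from the same address.
    %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %splat, i32 4,
             <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> poison)
    ret <4 x i32> %v
  }
  declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32, <4 x i1>, <4 x i32>)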

From 4e2096852f26edb361c7caadfbf56ac771011cd9 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 8 Sep 2023 07:57:12 -0700
Subject: [PATCH] [RISCV] Match gather(splat(ptr)) as zero strided load

We already handled the case where the broadcast was done via a GEP, but not the case where it was done via a shuffle.
---
 .../RISCV/RISCVGatherScatterLowering.cpp      | 30 ++++++++----
 .../RISCV/rvv/fixed-vectors-masked-gather.ll  | 47 ++-----------------
 2 files changed, 23 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index fac3526c43148d8..0e9244d0aefa813 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -67,7 +67,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                  Value *AlignOp);
 
-  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
+  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                      IRBuilderBase &Builder);
 
   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
@@ -321,9 +321,19 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
 }
 
 std::pair<Value *, Value *>
-RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
+RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                    IRBuilderBase &Builder) {
 
+  // A gather/scatter of a splat is a zero strided load/store.
+  if (auto *BasePtr = getSplatValue(Ptr)) {
+    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
+    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
+  }
+
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return std::make_pair(nullptr, nullptr);
+
   auto I = StridedAddrs.find(GEP);
   if (I != StridedAddrs.end())
     return I->second;
@@ -452,17 +462,17 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
   if (!TLI->isTypeLegal(DataTypeVT))
     return false;
 
-  // Pointer should be a GEP.
-  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
-  if (!GEP)
+  // Pointer should be an instruction.
+  auto *PtrI = dyn_cast<Instruction>(Ptr);
+  if (!PtrI)
     return false;
 
-  LLVMContext &Ctx = GEP->getContext();
+  LLVMContext &Ctx = PtrI->getContext();
   IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
-  Builder.SetInsertPoint(GEP);
+  Builder.SetInsertPoint(PtrI);
 
   Value *BasePtr, *Stride;
-  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
+  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
   if (!BasePtr)
     return false;
   assert(Stride != nullptr);
@@ -485,8 +495,8 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
   II->replaceAllUsesWith(Call);
   II->eraseFromParent();
 
-  if (GEP->use_empty())
-    RecursivelyDeleteTriviallyDeadInstructions(GEP);
+  if (PtrI->use_empty())
+    RecursivelyDeleteTriviallyDeadInstructions(PtrI);
 
   return true;
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index cb3ee899dde7d27..25ef59e111faed9 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -12918,60 +12918,19 @@ define <4 x i32> @mgather_broadcast_load_unmasked2(ptr %base) {
 ; RV32-LABEL: mgather_broadcast_load_unmasked2:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vmv.v.x v8, a0
-; RV32-NEXT:    vluxei32.v v8, (zero), v8
+; RV32-NEXT:    vlse32.v v8, (a0), zero
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_broadcast_load_unmasked2:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
-; RV64V-NEXT:    vmv.v.x v10, a0
-; RV64V-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; RV64V-NEXT:    vluxei64.v v8, (zero), v10
+; RV64V-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT:    vlse32.v v8, (a0), zero
 ; RV64V-NEXT:    ret
 ;
 ; RV64ZVE32F-LABEL: mgather_broadcast_load_unmasked2:
 ; RV64ZVE32F:       # %bb.0:
-; RV64ZVE32F-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; RV64ZVE32F-NEXT:    vmset.m v8
-; RV64ZVE32F-NEXT:    vmv.x.s a1, v8
-; RV64ZVE32F-NEXT:    # implicit-def: $v8
-; RV64ZVE32F-NEXT:    beqz zero, .LBB100_5
-; RV64ZVE32F-NEXT:  # %bb.1: # %else
-; RV64ZVE32F-NEXT:    andi a2, a1, 2
-; RV64ZVE32F-NEXT:    bnez a2, .LBB100_6
-; RV64ZVE32F-NEXT:  .LBB100_2: # %else2
-; RV64ZVE32F-NEXT:    andi a2, a1, 4
-; RV64ZVE32F-NEXT:    bnez a2, .LBB100_7
-; RV64ZVE32F-NEXT:  .LBB100_3: # %else5
-; RV64ZVE32F-NEXT:    andi a1, a1, 8
-; RV64ZVE32F-NEXT:    bnez a1, .LBB100_8
-; RV64ZVE32F-NEXT:  .LBB100_4: # %else8
-; RV64ZVE32F-NEXT:    ret
-; RV64ZVE32F-NEXT:  .LBB100_5: # %cond.load
 ; RV64ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; RV64ZVE32F-NEXT:    vlse32.v v8, (a0), zero
-; RV64ZVE32F-NEXT:    andi a2, a1, 2
-; RV64ZVE32F-NEXT:    beqz a2, .LBB100_2
-; RV64ZVE32F-NEXT:  .LBB100_6: # %cond.load1
-; RV64ZVE32F-NEXT:    lw a2, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a2
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 1
-; RV64ZVE32F-NEXT:    andi a2, a1, 4
-; RV64ZVE32F-NEXT:    beqz a2, .LBB100_3
-; RV64ZVE32F-NEXT:  .LBB100_7: # %cond.load4
-; RV64ZVE32F-NEXT:    lw a2, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 3, e32, m1, tu, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a2
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 2
-; RV64ZVE32F-NEXT:    andi a1, a1, 8
-; RV64ZVE32F-NEXT:    beqz a1, .LBB100_4
-; RV64ZVE32F-NEXT:  .LBB100_8: # %cond.load7
-; RV64ZVE32F-NEXT:    lw a0, 0(a0)
-; RV64ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV64ZVE32F-NEXT:    vmv.s.x v9, a0
-; RV64ZVE32F-NEXT:    vslideup.vi v8, v9, 3
 ; RV64ZVE32F-NEXT:    ret
   %head = insertelement <4 x i1> poison, i1 true, i32 0
   %allones = shufflevector <4 x i1> %head, <4 x i1> poison, <4 x i32> zeroinitializer


