[llvm] df2e728 - [RISCV] Teach RISCVGatherScatterLowering to handle more complex recurrence start values.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 4 10:13:45 PST 2022


Author: Craig Topper
Date: 2022-01-04T10:13:34-08:00
New Revision: df2e728b77510da33cf3822eae4d66531eeed518

URL: https://github.com/llvm/llvm-project/commit/df2e728b77510da33cf3822eae4d66531eeed518
DIFF: https://github.com/llvm/llvm-project/commit/df2e728b77510da33cf3822eae4d66531eeed518.diff

LOG: [RISCV] Teach RISCVGatherScatterLowering to handle more complex recurrence start values.

Previously we only recognized strided loads/stores when the initial
value for the phi was a strided constant vector.

This patch extends the support to a strided_constant added to a
splatted value. The rewritten loop will add the splat value to the
first element of the strided constant vector to use as the scalar
start value. The stride is unaffected.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D115958

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index d47bd739235fe..ba91b16661a46 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -127,6 +127,41 @@ static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
   return std::make_pair(StartVal, Stride);
 }
 
+static std::pair<Value *, Value *> matchStridedStart(Value *Start,
+                                                     IRBuilder<> &Builder) {
+  // Base case: a constant start is matched directly as a strided constant.
+  auto *StartC = dyn_cast<Constant>(Start);
+  if (StartC)
+    return matchStridedConstant(StartC);
+
+  // Not a constant, maybe it's a strided constant with a splat added to it.
+  auto *BO = dyn_cast<BinaryOperator>(Start);
+  if (!BO || BO->getOpcode() != Instruction::Add)
+    return std::make_pair(nullptr, nullptr);
+
+  // Look for a splatted operand; OtherIndex selects the remaining operand.
+  unsigned OtherIndex = 1;
+  Value *Splat = getSplatValue(BO->getOperand(0));
+  if (!Splat) {
+    Splat = getSplatValue(BO->getOperand(1));
+    OtherIndex = 0;
+  }
+  if (!Splat)
+    return std::make_pair(nullptr, nullptr);
+
+  Value *Stride;
+  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
+                                              Builder);
+  if (!Start)
+    return std::make_pair(nullptr, nullptr);
+
+  // Add the splat value to the start, using BO as the builder insert point.
+  Builder.SetInsertPoint(BO);
+  Builder.SetCurrentDebugLocation(DebugLoc());
+  Start = Builder.CreateAdd(Start, Splat);
+  return std::make_pair(Start, Stride);
+}
+
 // Recursively, walk about the use-def chain until we find a Phi with a strided
 // start value. Build and update a scalar recurrence as we unwind the recursion.
 // We also update the Stride as we unwind. Our goal is to move all of the
@@ -161,12 +196,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
     if (!Step)
       return false;
 
-    // Start should be a strided constant.
-    auto *StartC = dyn_cast<Constant>(Start);
-    if (!StartC)
-      return false;
-
-    std::tie(Start, Stride) = matchStridedConstant(StartC);
+    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
     if (!Start)
       return false;
     assert(Stride != nullptr);

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
index 026e149c1a464..e563b0834d607 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -978,3 +978,173 @@ define void @scatter_of_pointers(i32** noalias nocapture %0, i32** noalias nocap
 }
 
 declare void @llvm.masked.scatter.v2p0i32.v2p0p0i32(<2 x i32*>, <2 x i32**>, i32 immarg, <2 x i1>)
+
+define void @strided_load_startval_add_with_splat(i8* noalias nocapture %0, i8* noalias nocapture readonly %1, i32 signext %2) {
+; CHECK-LABEL: @strided_load_startval_add_with_splat(
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP2:%.*]], 1024
+; CHECK-NEXT:    br i1 [[TMP4]], label [[TMP31:%.*]], label [[TMP5:%.*]]
+; CHECK:       5:
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP2]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = sub i32 1023, [[TMP2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = zext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = add nuw nsw i64 [[TMP8]], 1
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ult i32 [[TMP7]], 31
+; CHECK-NEXT:    br i1 [[TMP10]], label [[TMP29:%.*]], label [[TMP11:%.*]]
+; CHECK:       11:
+; CHECK-NEXT:    [[TMP12:%.*]] = and i64 [[TMP9]], 8589934560
+; CHECK-NEXT:    [[TMP13:%.*]] = add nsw i64 [[TMP12]], [[TMP6]]
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 0, [[TMP6]]
+; CHECK-NEXT:    [[START:%.*]] = mul i64 [[TMP14]], 5
+; CHECK-NEXT:    br label [[TMP15:%.*]]
+; CHECK:       15:
+; CHECK-NEXT:    [[TMP16:%.*]] = phi i64 [ 0, [[TMP11]] ], [ [[TMP25:%.*]], [[TMP15]] ]
+; CHECK-NEXT:    [[DOTSCALAR:%.*]] = phi i64 [ [[START]], [[TMP11]] ], [ [[DOTSCALAR1:%.*]], [[TMP15]] ]
+; CHECK-NEXT:    [[TMP17:%.*]] = add i64 [[TMP16]], [[TMP6]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i8, i8* [[TMP1:%.*]], i64 [[DOTSCALAR]]
+; CHECK-NEXT:    [[TMP19:%.*]] = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0i8.i64(<32 x i8> undef, i8* [[TMP18]], i64 5, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr inbounds i8, i8* [[TMP0:%.*]], i64 [[TMP17]]
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast i8* [[TMP20]] to <32 x i8>*
+; CHECK-NEXT:    [[TMP22:%.*]] = load <32 x i8>, <32 x i8>* [[TMP21]], align 1
+; CHECK-NEXT:    [[TMP23:%.*]] = add <32 x i8> [[TMP22]], [[TMP19]]
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast i8* [[TMP20]] to <32 x i8>*
+; CHECK-NEXT:    store <32 x i8> [[TMP23]], <32 x i8>* [[TMP24]], align 1
+; CHECK-NEXT:    [[TMP25]] = add nuw i64 [[TMP16]], 32
+; CHECK-NEXT:    [[DOTSCALAR1]] = add i64 [[DOTSCALAR]], 160
+; CHECK-NEXT:    [[TMP26:%.*]] = icmp eq i64 [[TMP25]], [[TMP12]]
+; CHECK-NEXT:    br i1 [[TMP26]], label [[TMP27:%.*]], label [[TMP15]]
+; CHECK:       27:
+; CHECK-NEXT:    [[TMP28:%.*]] = icmp eq i64 [[TMP9]], [[TMP12]]
+; CHECK-NEXT:    br i1 [[TMP28]], label [[TMP31]], label [[TMP29]]
+; CHECK:       29:
+; CHECK-NEXT:    [[TMP30:%.*]] = phi i64 [ [[TMP6]], [[TMP5]] ], [ [[TMP13]], [[TMP27]] ]
+; CHECK-NEXT:    br label [[TMP32:%.*]]
+; CHECK:       31:
+; CHECK-NEXT:    ret void
+; CHECK:       32:
+; CHECK-NEXT:    [[TMP33:%.*]] = phi i64 [ [[TMP40:%.*]], [[TMP32]] ], [ [[TMP30]], [[TMP29]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = mul nsw i64 [[TMP33]], 5
+; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i8, i8* [[TMP1]], i64 [[TMP34]]
+; CHECK-NEXT:    [[TMP36:%.*]] = load i8, i8* [[TMP35]], align 1
+; CHECK-NEXT:    [[TMP37:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i64 [[TMP33]]
+; CHECK-NEXT:    [[TMP38:%.*]] = load i8, i8* [[TMP37]], align 1
+; CHECK-NEXT:    [[TMP39:%.*]] = add i8 [[TMP38]], [[TMP36]]
+; CHECK-NEXT:    store i8 [[TMP39]], i8* [[TMP37]], align 1
+; CHECK-NEXT:    [[TMP40]] = add nsw i64 [[TMP33]], 1
+; CHECK-NEXT:    [[TMP41:%.*]] = trunc i64 [[TMP40]] to i32
+; CHECK-NEXT:    [[TMP42:%.*]] = icmp eq i32 [[TMP41]], 1024
+; CHECK-NEXT:    br i1 [[TMP42]], label [[TMP31]], label [[TMP32]]
+;
+; CHECK-ASM-LABEL: strided_load_startval_add_with_splat:
+; CHECK-ASM:       # %bb.0:
+; CHECK-ASM-NEXT:    li a3, 1024
+; CHECK-ASM-NEXT:    beq a2, a3, .LBB12_7
+; CHECK-ASM-NEXT:  # %bb.1:
+; CHECK-ASM-NEXT:    li a3, 1023
+; CHECK-ASM-NEXT:    subw a4, a3, a2
+; CHECK-ASM-NEXT:    li a5, 31
+; CHECK-ASM-NEXT:    mv a3, a2
+; CHECK-ASM-NEXT:    bltu a4, a5, .LBB12_5
+; CHECK-ASM-NEXT:  # %bb.2:
+; CHECK-ASM-NEXT:    slli a3, a4, 32
+; CHECK-ASM-NEXT:    srli a3, a3, 32
+; CHECK-ASM-NEXT:    addi a6, a3, 1
+; CHECK-ASM-NEXT:    andi a7, a6, -32
+; CHECK-ASM-NEXT:    add a3, a7, a2
+; CHECK-ASM-NEXT:    slli a4, a2, 2
+; CHECK-ASM-NEXT:    add a4, a4, a2
+; CHECK-ASM-NEXT:    add a2, a0, a2
+; CHECK-ASM-NEXT:    add a4, a1, a4
+; CHECK-ASM-NEXT:    li t0, 32
+; CHECK-ASM-NEXT:    li t1, 5
+; CHECK-ASM-NEXT:    mv a5, a7
+; CHECK-ASM-NEXT:  .LBB12_3: # =>This Inner Loop Header: Depth=1
+; CHECK-ASM-NEXT:    vsetvli zero, t0, e8, m1, ta, mu
+; CHECK-ASM-NEXT:    vlse8.v v8, (a4), t1
+; CHECK-ASM-NEXT:    vle8.v v9, (a2)
+; CHECK-ASM-NEXT:    vadd.vv v8, v9, v8
+; CHECK-ASM-NEXT:    vse8.v v8, (a2)
+; CHECK-ASM-NEXT:    addi a5, a5, -32
+; CHECK-ASM-NEXT:    addi a2, a2, 32
+; CHECK-ASM-NEXT:    addi a4, a4, 160
+; CHECK-ASM-NEXT:    bnez a5, .LBB12_3
+; CHECK-ASM-NEXT:  # %bb.4:
+; CHECK-ASM-NEXT:    beq a6, a7, .LBB12_7
+; CHECK-ASM-NEXT:  .LBB12_5:
+; CHECK-ASM-NEXT:    slli a2, a3, 2
+; CHECK-ASM-NEXT:    add a2, a2, a3
+; CHECK-ASM-NEXT:    add a1, a1, a2
+; CHECK-ASM-NEXT:    li a6, 1024
+; CHECK-ASM-NEXT:  .LBB12_6: # =>This Inner Loop Header: Depth=1
+; CHECK-ASM-NEXT:    lb a4, 0(a1)
+; CHECK-ASM-NEXT:    add a5, a0, a3
+; CHECK-ASM-NEXT:    lb a2, 0(a5)
+; CHECK-ASM-NEXT:    addw a2, a2, a4
+; CHECK-ASM-NEXT:    sb a2, 0(a5)
+; CHECK-ASM-NEXT:    addiw a2, a3, 1
+; CHECK-ASM-NEXT:    addi a3, a3, 1
+; CHECK-ASM-NEXT:    addi a1, a1, 5
+; CHECK-ASM-NEXT:    bne a2, a6, .LBB12_6
+; CHECK-ASM-NEXT:  .LBB12_7:
+; CHECK-ASM-NEXT:    ret
+  %4 = icmp eq i32 %2, 1024
+  br i1 %4, label %36, label %5
+
+5:                                                ; preds = %3
+  %6 = sext i32 %2 to i64
+  %7 = sub i32 1023, %2
+  %8 = zext i32 %7 to i64
+  %9 = add nuw nsw i64 %8, 1
+  %10 = icmp ult i32 %7, 31
+  br i1 %10, label %34, label %11
+
+11:                                               ; preds = %5
+  %12 = and i64 %9, 8589934560
+  %13 = add nsw i64 %12, %6
+  %14 = insertelement <32 x i64> poison, i64 %6, i64 0
+  %15 = shufflevector <32 x i64> %14, <32 x i64> poison, <32 x i32> zeroinitializer
+  %16 = add <32 x i64> %15, <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>
+  br label %17
+
+17:                                               ; preds = %17, %11
+  %18 = phi i64 [ 0, %11 ], [ %29, %17 ]
+  %19 = phi <32 x i64> [ %16, %11 ], [ %30, %17 ]
+  %20 = add i64 %18, %6
+  %21 = mul nsw <32 x i64> %19, <i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5>
+  %22 = getelementptr inbounds i8, i8* %1, <32 x i64> %21
+  %23 = call <32 x i8> @llvm.masked.gather.v32i8.v32p0i8(<32 x i8*> %22, i32 1, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <32 x i8> undef)
+  %24 = getelementptr inbounds i8, i8* %0, i64 %20
+  %25 = bitcast i8* %24 to <32 x i8>*
+  %26 = load <32 x i8>, <32 x i8>* %25, align 1
+  %27 = add <32 x i8> %26, %23
+  %28 = bitcast i8* %24 to <32 x i8>*
+  store <32 x i8> %27, <32 x i8>* %28, align 1
+  %29 = add nuw i64 %18, 32
+  %30 = add <32 x i64> %19, <i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32>
+  %31 = icmp eq i64 %29, %12
+  br i1 %31, label %32, label %17
+
+32:                                               ; preds = %17
+  %33 = icmp eq i64 %9, %12
+  br i1 %33, label %36, label %34
+
+34:                                               ; preds = %5, %32
+  %35 = phi i64 [ %6, %5 ], [ %13, %32 ]
+  br label %37
+
+36:                                               ; preds = %37, %32, %3
+  ret void
+
+37:                                               ; preds = %34, %37
+  %38 = phi i64 [ %45, %37 ], [ %35, %34 ]
+  %39 = mul nsw i64 %38, 5
+  %40 = getelementptr inbounds i8, i8* %1, i64 %39
+  %41 = load i8, i8* %40, align 1
+  %42 = getelementptr inbounds i8, i8* %0, i64 %38
+  %43 = load i8, i8* %42, align 1
+  %44 = add i8 %43, %41
+  store i8 %44, i8* %42, align 1
+  %45 = add nsw i64 %38, 1
+  %46 = trunc i64 %45 to i32
+  %47 = icmp eq i32 %46, 1024
+  br i1 %47, label %36, label %37
+}


        


More information about the llvm-commits mailing list