[llvm] 297e06c - [RISCVGatherScatterLowering] Remove restriction that shift must have constant operand

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 12 12:13:09 PDT 2023


Author: Philip Reames
Date: 2023-05-12T12:08:33-07:00
New Revision: 297e06cf4b03e4c4840c580e6314e9c4c19b856f

URL: https://github.com/llvm/llvm-project/commit/297e06cf4b03e4c4840c580e6314e9c4c19b856f
DIFF: https://github.com/llvm/llvm-project/commit/297e06cf4b03e4c4840c580e6314e9c4c19b856f.diff

LOG: [RISCVGatherScatterLowering] Remove restriction that shift must have constant operand

This restriction has been present since the original patch which added the pass, and doesn't appear to be strongly justified. We do need to be careful about commutativity: when the recurrence is the second operand of the binary operator, we can only swap the operands into the expected position if the opcode is commutative, which shl is not.

Differential Revision: https://reviews.llvm.org/D150468
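
For illustration, a minimal sketch distilled from the tests added below: the pass now matches a loop-invariant, non-constant shift amount when the recurrence is the first shl operand, and bails out when the operands are swapped, since shl is not commutative.

  ; Matched (@gather_unknown_pow2): the recurrence %vec.ind is operand 0,
  ; so the gather is lowered to a strided load with stride (4 << %shift)
  ; bytes for the i32 elements.
  %i = shl nsw <8 x i64> %vec.ind, %.splat

  ; Not matched (@negative_shl_non_commute): the recurrence is operand 1,
  ; and since shl is not commutative the operands cannot be swapped; the
  ; masked.gather is left as-is.
  %i = shl nsw <8 x i64> %.splat, %vec.ind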

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 0174aa547a1b8..30c1e39022d5d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -236,9 +236,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
   case Instruction::Add:
     break;
   case Instruction::Shl:
-    // Only support shift by constant.
-    if (!isa<Constant>(BO->getOperand(1)))
-      return false;
     break;
   case Instruction::Mul:
     break;
@@ -251,7 +248,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
     Index = cast<Instruction>(BO->getOperand(0));
     OtherOp = BO->getOperand(1);
   } else if (isa<Instruction>(BO->getOperand(1)) &&
-             L->contains(cast<Instruction>(BO->getOperand(1)))) {
+             L->contains(cast<Instruction>(BO->getOperand(1))) &&
+             Instruction::isCommutative(BO->getOpcode())) {
     Index = cast<Instruction>(BO->getOperand(1));
     OtherOp = BO->getOperand(0);
   } else {

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
index e3ed26d1a66e6..b59d96819cbfd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -310,6 +310,100 @@ for.cond.cleanup:                                 ; preds = %vector.body
   ret void
 }
 
+define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @gather_unknown_pow2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
+; CHECK-NEXT:    [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(<8 x i32> undef, ptr [[TMP1]], i64 [[TMP0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT:    [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT:    store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT:    [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
+; CHECK-NEXT:    [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT:    br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = shl nsw <8 x i64> %vec.ind, %.splat
+  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+  %wide.load = load <8 x i32>, ptr %i2, align 1
+  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+  store <8 x i32> %i4, ptr %i2, align 1
+  %index.next = add nuw i64 %index, 8
+  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
+define void @negative_shl_non_commute(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @negative_shl_non_commute(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[SHIFT:%.*]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[I:%.*]] = shl nsw <8 x i64> [[DOTSPLAT]], [[VEC_IND]]
+; CHECK-NEXT:    [[I1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], <8 x i64> [[I]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[I1]], i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+; CHECK-NEXT:    [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT:    [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT:    store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+; CHECK-NEXT:    [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT:    br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = shl nsw <8 x i64> %.splat, %vec.ind
+  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+  %wide.load = load <8 x i32>, ptr %i2, align 1
+  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+  store <8 x i32> %i4, ptr %i2, align 1
+  %index.next = add nuw i64 %index, 8
+  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
 ;void scatter_pow2(signed char * __restrict  A, signed char * __restrict B) {
 ;  for (int i = 0; i < 1024; ++i)
 ;      A[i * 4] += B[i];
