[llvm] 297e06c - [RISCVGatherScatterLowering] Remove restriction that shift must have constant operand
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri May 12 12:13:09 PDT 2023
Author: Philip Reames
Date: 2023-05-12T12:08:33-07:00
New Revision: 297e06cf4b03e4c4840c580e6314e9c4c19b856f
URL: https://github.com/llvm/llvm-project/commit/297e06cf4b03e4c4840c580e6314e9c4c19b856f
DIFF: https://github.com/llvm/llvm-project/commit/297e06cf4b03e4c4840c580e6314e9c4c19b856f.diff
LOG: [RISCVGatherScatterLowering] Remove restriction that shift must have constant operand
This restriction has been present since the original patch that added the pass, and doesn't appear to be strongly justified. We do need to be careful about commutativity: since shl is not commutative, the strided recurrence may only appear as operand 0, and the operands must not be swapped.
Differential Revision: https://reviews.llvm.org/D150468
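
For illustration (these two lines are adapted from the new tests below, not an additional change in this commit): a shl whose shift amount is loop-invariant but non-constant is now handled, but only when the strided recurrence feeds operand 0.

    %i.ok  = shl nsw <8 x i64> %vec.ind, %.splat   ; index << shift: strided, now matched
    %i.bad = shl nsw <8 x i64> %.splat, %vec.ind   ; shift << index: not strided, still rejected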
Added:
Modified:
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 0174aa547a1b8..30c1e39022d5d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -236,9 +236,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
case Instruction::Add:
break;
case Instruction::Shl:
- // Only support shift by constant.
- if (!isa<Constant>(BO->getOperand(1)))
- return false;
break;
case Instruction::Mul:
break;
@@ -251,7 +248,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
Index = cast<Instruction>(BO->getOperand(0));
OtherOp = BO->getOperand(1);
} else if (isa<Instruction>(BO->getOperand(1)) &&
- L->contains(cast<Instruction>(BO->getOperand(1)))) {
+ L->contains(cast<Instruction>(BO->getOperand(1))) &&
+ Instruction::isCommutative(BO->getOpcode())) {
Index = cast<Instruction>(BO->getOperand(1));
OtherOp = BO->getOperand(0);
} else {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
index e3ed26d1a66e6..b59d96819cbfd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -310,6 +310,100 @@ for.cond.cleanup: ; preds = %vector.body
ret void
}
+define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @gather_unknown_pow2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
+; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
+; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(<8 x i32> undef, ptr [[TMP1]], i64 [[TMP0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
+; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: ret void
+;
+entry:
+ %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+ %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+ br label %vector.body
+
+vector.body: ; preds = %vector.body, %entry
+ %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+ %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+ %i = shl nsw <8 x i64> %vec.ind, %.splat
+ %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+ %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+ %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+ %wide.load = load <8 x i32>, ptr %i2, align 1
+ %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+ store <8 x i32> %i4, ptr %i2, align 1
+ %index.next = add nuw i64 %index, 8
+ %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+ %i6 = icmp eq i64 %index.next, 1024
+ br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.body
+ ret void
+}
+
+define void @negative_shl_non_commute(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @negative_shl_non_commute(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[SHIFT:%.*]], i64 0
+; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[I:%.*]] = shl nsw <8 x i64> [[DOTSPLAT]], [[VEC_IND]]
+; CHECK-NEXT: [[I1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], <8 x i64> [[I]]
+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[I1]], i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT: [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: ret void
+;
+entry:
+ %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+ %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+ br label %vector.body
+
+vector.body: ; preds = %vector.body, %entry
+ %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+ %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+ %i = shl nsw <8 x i64> %.splat, %vec.ind
+ %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+ %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+ %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+ %wide.load = load <8 x i32>, ptr %i2, align 1
+ %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+ store <8 x i32> %i4, ptr %i2, align 1
+ %index.next = add nuw i64 %index, 8
+ %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+ %i6 = icmp eq i64 %index.next, 1024
+ br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.body
+ ret void
+}
+
;void scatter_pow2(signed char * __restrict A, signed char * __restrict B) {
; for (int i = 0; i < 1024; ++i)
; A[i * 4] += B[i];