[llvm] 4cc2010 - [RISCVGatherScatterLowering] Use InstSimplifyFolder

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 12 17:33:16 PDT 2023


Author: Philip Reames
Date: 2023-05-12T17:33:07-07:00
New Revision: 4cc2010f0368b71316acd5895f9a016e7f34c5fb

URL: https://github.com/llvm/llvm-project/commit/4cc2010f0368b71316acd5895f9a016e7f34c5fb
DIFF: https://github.com/llvm/llvm-project/commit/4cc2010f0368b71316acd5895f9a016e7f34c5fb.diff

LOG: [RISCVGatherScatterLowering] Use InstSimplifyFolder

Main value of this is simplifying code, and making a few of the tests easier to read.

Differential Revision: https://reviews.llvm.org/D150474

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
    llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 5e527c60ca5f..b9c69a966b4a 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -13,6 +13,7 @@
 
 #include "RISCV.h"
 #include "RISCVTargetMachine.h"
+#include "llvm/Analysis/InstSimplifyFolder.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -67,11 +68,11 @@ class RISCVGatherScatterLowering : public FunctionPass {
                                  Value *AlignOp);
 
   std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
-                                                     IRBuilder<> &Builder);
+                                                     IRBuilderBase &Builder);
 
   bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                               PHINode *&BasePtr, BinaryOperator *&Inc,
-                              IRBuilder<> &Builder);
+                              IRBuilderBase &Builder);
 };
 
 } // end anonymous namespace
@@ -119,7 +120,7 @@ static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
 }
 
 static std::pair<Value *, Value *> matchStridedStart(Value *Start,
-                                                     IRBuilder<> &Builder) {
+                                                     IRBuilderBase &Builder) {
   // Base case, start is a strided constant.
   auto *StartC = dyn_cast<Constant>(Start);
   if (StartC)
@@ -186,7 +187,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                         Value *&Stride,
                                                         PHINode *&BasePtr,
                                                         BinaryOperator *&Inc,
-                                                        IRBuilder<> &Builder) {
+                                                        IRBuilderBase &Builder) {
   // Our base case is a Phi.
   if (auto *Phi = dyn_cast<PHINode>(Index)) {
     // A phi node we want to perform this function on should be from the
@@ -297,32 +298,17 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
   case Instruction::Or: {
     // An add only affects the start value. It's ok to do this for Or because
     // we already checked that there are no common set bits.
-
-    // If the start value is Zero, just take the SplatOp.
-    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
-      Start = SplatOp;
-    else
-      Start = Builder.CreateAdd(Start, SplatOp, "start");
+    Start = Builder.CreateAdd(Start, SplatOp, "start");
     break;
   }
   case Instruction::Mul: {
-    // If the start is zero we don't need to multiply.
-    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
-      Start = Builder.CreateMul(Start, SplatOp, "start");
-
+    Start = Builder.CreateMul(Start, SplatOp, "start");
     Step = Builder.CreateMul(Step, SplatOp, "step");
-
-    // If the Stride is 1 just take the SplatOpt.
-    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
-      Stride = SplatOp;
-    else
-      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
+    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
     break;
   }
   case Instruction::Shl: {
-    // If the start is zero we don't need to shift.
-    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
-      Start = Builder.CreateShl(Start, SplatOp, "start");
+    Start = Builder.CreateShl(Start, SplatOp, "start");
     Step = Builder.CreateShl(Step, SplatOp, "step");
     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
     break;
@@ -336,7 +322,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
 
 std::pair<Value *, Value *>
 RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
-                                                   IRBuilder<> &Builder) {
+                                                   IRBuilderBase &Builder) {
 
   auto I = StridedAddrs.find(GEP);
   if (I != StridedAddrs.end())
@@ -467,7 +453,9 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
   if (!GEP)
     return false;
 
-  IRBuilder<> Builder(GEP);
+  LLVMContext &Ctx = GEP->getContext();
+  IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
+  Builder.SetInsertPoint(GEP);
 
   Value *BasePtr, *Stride;
   std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
index b59d96819cbf..ef8c2c192311 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -794,15 +794,14 @@ define void @strided_load_startval_add_with_splat(ptr noalias nocapture %arg, pt
 ; CHECK:       bb9:
 ; CHECK-NEXT:    [[I10:%.*]] = and i64 [[I7]], 8589934560
 ; CHECK-NEXT:    [[I11:%.*]] = add nsw i64 [[I10]], [[I4]]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 0, [[I4]]
-; CHECK-NEXT:    [[START:%.*]] = mul i64 [[TMP0]], 5
+; CHECK-NEXT:    [[START:%.*]] = mul i64 [[I4]], 5
 ; CHECK-NEXT:    br label [[BB15:%.*]]
 ; CHECK:       bb15:
 ; CHECK-NEXT:    [[I16:%.*]] = phi i64 [ 0, [[BB9]] ], [ [[I27:%.*]], [[BB15]] ]
 ; CHECK-NEXT:    [[I17_SCALAR:%.*]] = phi i64 [ [[START]], [[BB9]] ], [ [[I28_SCALAR:%.*]], [[BB15]] ]
 ; CHECK-NEXT:    [[I18:%.*]] = add i64 [[I16]], [[I4]]
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[ARG1:%.*]], i64 [[I17_SCALAR]]
-; CHECK-NEXT:    [[I21:%.*]] = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr [[TMP1]], i64 5, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[ARG1:%.*]], i64 [[I17_SCALAR]]
+; CHECK-NEXT:    [[I21:%.*]] = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr [[TMP0]], i64 5, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
 ; CHECK-NEXT:    [[I22:%.*]] = getelementptr inbounds i8, ptr [[ARG:%.*]], i64 [[I18]]
 ; CHECK-NEXT:    [[I24:%.*]] = load <32 x i8>, ptr [[I22]], align 1
 ; CHECK-NEXT:    [[I25:%.*]] = add <32 x i8> [[I24]], [[I21]]

diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index e5a2da6b8060..fcb3742eb236 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -91,11 +91,8 @@ for.cond.cleanup:                                 ; preds = %vector.body
 
 define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
 ; CHECK-LABEL: @gather_loopless(
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[STRIDE:%.*]], 4
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[P:%.*]], i64 [[TMP1]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -114,9 +111,8 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
 
 define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
 ; CHECK-LABEL: @straightline_offset_add(
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP2]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[OFFSET:%.*]]
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -135,8 +131,7 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
 
 define <vscale x 1 x i64> @straightline_offset_shl(ptr %p) {
 ; CHECK-LABEL: @straightline_offset_shl(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[P:%.*]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -179,11 +174,9 @@ define <vscale x 1 x i64> @neg_shl_is_not_commutative(ptr %p) {
 
 define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
 ; CHECK-LABEL: @straightline_offset_shl_nonc(
-; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 1, [[SHIFT]]
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 1, [[SHIFT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[P:%.*]], i64 [[TMP2]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -202,11 +195,8 @@ define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
 
 define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
 ; CHECK-LABEL: @scatter_loopless(
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
-; CHECK-NEXT:    call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[STRIDE:%.*]], 4
+; CHECK-NEXT:    call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[P:%.*]], i64 [[TMP1]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()


        


More information about the llvm-commits mailing list