[llvm] 182b831 - [DAGCombiner][RISCV] Teach visitMGATHER/MSCATTER to remove gathers/scatters with all-zeros masks that use SPLAT_VECTOR.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 18 15:39:37 PDT 2021


Author: Craig Topper
Date: 2021-03-18T15:34:14-07:00
New Revision: 182b831aebc0569e8344d848fa20f0c67f43d55a

URL: https://github.com/llvm/llvm-project/commit/182b831aebc0569e8344d848fa20f0c67f43d55a
DIFF: https://github.com/llvm/llvm-project/commit/182b831aebc0569e8344d848fa20f0c67f43d55a.diff

LOG: [DAGCombiner][RISCV] Teach visitMGATHER/MSCATTER to remove gathers/scatters with all-zeros masks that use SPLAT_VECTOR.

Previously, only an all-zeros BUILD_VECTOR mask was recognized. Scalable vectors represent an all-zeros mask with SPLAT_VECTOR rather than BUILD_VECTOR, so gathers and scatters with such masks were not removed.
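
For illustration, here is a reduced form of the new tests added below
(the function name @gather_zero_mask is illustrative only, not part of
the patch): a masked gather whose mask is zeroinitializer over a
scalable vector type now folds to its passthru operand, so no memory
access is emitted.

  declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)

  define <vscale x 4 x i32> @gather_zero_mask(<vscale x 4 x i32*> %ptrs, <vscale x 4 x i32> %passthru) {
    ; For scalable vectors the all-zeros i1 mask becomes a SPLAT_VECTOR of 0
    ; in the DAG, which ISD::isConstantSplatVectorAllZeros now matches, so
    ; the gather is replaced by %passthru.
    %v = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> %passthru)
    ret <vscale x 4 x i32> %v
  }

Scatters with an all-zeros mask are deleted outright; only the chain remains.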

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
    llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1c063dae9d88..382fc91285a0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9618,7 +9618,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   SDLoc DL(N);
 
   // Zap scatters with a zero mask.
-  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
+  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
   if (refineUniformBase(BasePtr, Index, DAG)) {
@@ -9674,7 +9674,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   SDLoc DL(N);
 
   // Zap gathers with a zero mask.
-  if (ISD::isBuildVectorAllZeros(Mask.getNode()))
+  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
   if (refineUniformBase(BasePtr, Index, DAG)) {

diff --git a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
index c5f9ea8aa3e3..d567ff9a0140 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
@@ -210,6 +210,20 @@ define <vscale x 4 x i8> @mgather_truemask_nxv4i8(<vscale x 4 x i8*> %ptrs, <vsc
   ret <vscale x 4 x i8> %v
 }
 
+define <vscale x 4 x i8> @mgather_falsemask_nxv4i8(<vscale x 4 x i8*> %ptrs, <vscale x 4 x i8> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4i8:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv1r.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4i8:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv1r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8.nxv4p0i8(<vscale x 4 x i8*> %ptrs, i32 1, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i8> %passthru)
+  ret <vscale x 4 x i8> %v
+}
+
 declare <vscale x 8 x i8> @llvm.masked.gather.nxv8i8.nxv8p0i8(<vscale x 8 x i8*>, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
 
 define <vscale x 8 x i8> @mgather_nxv8i8(<vscale x 8 x i8*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x i8> %passthru) {
@@ -417,6 +431,20 @@ define <vscale x 4 x i16> @mgather_truemask_nxv4i16(<vscale x 4 x i16*> %ptrs, <
   ret <vscale x 4 x i16> %v
 }
 
+define <vscale x 4 x i16> @mgather_falsemask_nxv4i16(<vscale x 4 x i16*> %ptrs, <vscale x 4 x i16> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4i16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv1r.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4i16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv1r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x i16> @llvm.masked.gather.nxv4i16.nxv4p0i16(<vscale x 4 x i16*> %ptrs, i32 2, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i16> %passthru)
+  ret <vscale x 4 x i16> %v
+}
+
 declare <vscale x 8 x i16> @llvm.masked.gather.nxv8i16.nxv8p0i16(<vscale x 8 x i16*>, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
 
 define <vscale x 8 x i16> @mgather_nxv8i16(<vscale x 8 x i16*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x i16> %passthru) {
@@ -661,6 +689,20 @@ define <vscale x 4 x i32> @mgather_truemask_nxv4i32(<vscale x 4 x i32*> %ptrs, <
   ret <vscale x 4 x i32> %v
 }
 
+define <vscale x 4 x i32> @mgather_falsemask_nxv4i32(<vscale x 4 x i32*> %ptrs, <vscale x 4 x i32> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4i32:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv2r.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4i32:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv2r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> %passthru)
+  ret <vscale x 4 x i32> %v
+}
+
 declare <vscale x 8 x i32> @llvm.masked.gather.nxv8i32.nxv8p0i32(<vscale x 8 x i32*>, i32, <vscale x 8 x i1>, <vscale x 8 x i32>)
 
 define <vscale x 8 x i32> @mgather_nxv8i32(<vscale x 8 x i32*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x i32> %passthru) {
@@ -937,6 +979,20 @@ define <vscale x 4 x i64> @mgather_truemask_nxv4i64(<vscale x 4 x i64*> %ptrs, <
   ret <vscale x 4 x i64> %v
 }
 
+define <vscale x 4 x i64> @mgather_falsemask_nxv4i64(<vscale x 4 x i64*> %ptrs, <vscale x 4 x i64> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4i64:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv4r.v v8, v12
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4i64:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv4r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x i64> @llvm.masked.gather.nxv4i64.nxv4p0i64(<vscale x 4 x i64*> %ptrs, i32 8, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i64> %passthru)
+  ret <vscale x 4 x i64> %v
+}
+
 declare <vscale x 8 x i64> @llvm.masked.gather.nxv8i64.nxv8p0i64(<vscale x 8 x i64*>, i32, <vscale x 8 x i1>, <vscale x 8 x i64>)
 
 define <vscale x 8 x i64> @mgather_nxv8i64(<vscale x 8 x i64*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x i64> %passthru) {
@@ -1354,6 +1410,20 @@ define <vscale x 4 x half> @mgather_truemask_nxv4f16(<vscale x 4 x half*> %ptrs,
   ret <vscale x 4 x half> %v
 }
 
+define <vscale x 4 x half> @mgather_falsemask_nxv4f16(<vscale x 4 x half*> %ptrs, <vscale x 4 x half> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4f16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv1r.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4f16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv1r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x half> @llvm.masked.gather.nxv4f16.nxv4p0f16(<vscale x 4 x half*> %ptrs, i32 2, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x half> %passthru)
+  ret <vscale x 4 x half> %v
+}
+
 declare <vscale x 8 x half> @llvm.masked.gather.nxv8f16.nxv8p0f16(<vscale x 8 x half*>, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
 
 define <vscale x 8 x half> @mgather_nxv8f16(<vscale x 8 x half*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x half> %passthru) {
@@ -1554,6 +1624,20 @@ define <vscale x 4 x float> @mgather_truemask_nxv4f32(<vscale x 4 x float*> %ptr
   ret <vscale x 4 x float> %v
 }
 
+define <vscale x 4 x float> @mgather_falsemask_nxv4f32(<vscale x 4 x float*> %ptrs, <vscale x 4 x float> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4f32:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv2r.v v8, v10
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4f32:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv2r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0f32(<vscale x 4 x float*> %ptrs, i32 4, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x float> %passthru)
+  ret <vscale x 4 x float> %v
+}
+
 declare <vscale x 8 x float> @llvm.masked.gather.nxv8f32.nxv8p0f32(<vscale x 8 x float*>, i32, <vscale x 8 x i1>, <vscale x 8 x float>)
 
 define <vscale x 8 x float> @mgather_nxv8f32(<vscale x 8 x float*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x float> %passthru) {
@@ -1830,6 +1914,20 @@ define <vscale x 4 x double> @mgather_truemask_nxv4f64(<vscale x 4 x double*> %p
   ret <vscale x 4 x double> %v
 }
 
+define <vscale x 4 x double> @mgather_falsemask_nxv4f64(<vscale x 4 x double*> %ptrs, <vscale x 4 x double> %passthru) {
+; RV32-LABEL: mgather_falsemask_nxv4f64:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vmv4r.v v8, v12
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mgather_falsemask_nxv4f64:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vmv4r.v v8, v12
+; RV64-NEXT:    ret
+  %v = call <vscale x 4 x double> @llvm.masked.gather.nxv4f64.nxv4p0f64(<vscale x 4 x double*> %ptrs, i32 8, <vscale x 4 x i1> zeroinitializer, <vscale x 4 x double> %passthru)
+  ret <vscale x 4 x double> %v
+}
+
 declare <vscale x 8 x double> @llvm.masked.gather.nxv8f64.nxv8p0f64(<vscale x 8 x double*>, i32, <vscale x 8 x i1>, <vscale x 8 x double>)
 
 define <vscale x 8 x double> @mgather_nxv8f64(<vscale x 8 x double*> %ptrs, <vscale x 8 x i1> %m, <vscale x 8 x double> %passthru) {

diff --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll
index 424ea2f90458..57a9e0019f7a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-sdnode.ll
@@ -145,6 +145,18 @@ define void @mscatter_truemask_nxv4i8(<vscale x 4 x i8> %val, <vscale x 4 x i8*>
   ret void
 }
 
+define void @mscatter_falsemask_nxv4i8(<vscale x 4 x i8> %val, <vscale x 4 x i8*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4i8:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4i8:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4i8.nxv4p0i8(<vscale x 4 x i8> %val, <vscale x 4 x i8*> %ptrs, i32 1, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8i8.nxv8p0i8(<vscale x 8 x i8>, <vscale x 8 x i8*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8i8(<vscale x 8 x i8> %val, <vscale x 8 x i8*> %ptrs, <vscale x 8 x i1> %m) {
@@ -298,6 +310,18 @@ define void @mscatter_truemask_nxv4i16(<vscale x 4 x i16> %val, <vscale x 4 x i1
   ret void
 }
 
+define void @mscatter_falsemask_nxv4i16(<vscale x 4 x i16> %val, <vscale x 4 x i16*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4i16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4i16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %val, <vscale x 4 x i16*> %ptrs, i32 2, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8i16.nxv8p0i16(<vscale x 8 x i16>, <vscale x 8 x i16*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8i16(<vscale x 8 x i16> %val, <vscale x 8 x i16*> %ptrs, <vscale x 8 x i1> %m) {
@@ -501,6 +525,18 @@ define void @mscatter_truemask_nxv4i32(<vscale x 4 x i32> %val, <vscale x 4 x i3
   ret void
 }
 
+define void @mscatter_falsemask_nxv4i32(<vscale x 4 x i32> %val, <vscale x 4 x i32*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4i32:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4i32:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> %val, <vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8i32.nxv8p0i32(<vscale x 8 x i32>, <vscale x 8 x i32*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8i32(<vscale x 8 x i32> %val, <vscale x 8 x i32*> %ptrs, <vscale x 8 x i1> %m) {
@@ -748,6 +784,18 @@ define void @mscatter_truemask_nxv4i64(<vscale x 4 x i64> %val, <vscale x 4 x i6
   ret void
 }
 
+define void @mscatter_falsemask_nxv4i64(<vscale x 4 x i64> %val, <vscale x 4 x i64*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4i64:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4i64:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4i64.nxv4p0i64(<vscale x 4 x i64> %val, <vscale x 4 x i64*> %ptrs, i32 8, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8i64.nxv8p0i64(<vscale x 8 x i64>, <vscale x 8 x i64*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8i64(<vscale x 8 x i64> %val, <vscale x 8 x i64*> %ptrs, <vscale x 8 x i1> %m) {
@@ -1054,6 +1102,18 @@ define void @mscatter_truemask_nxv4f16(<vscale x 4 x half> %val, <vscale x 4 x h
   ret void
 }
 
+define void @mscatter_falsemask_nxv4f16(<vscale x 4 x half> %val, <vscale x 4 x half*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4f16:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4f16:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4f16.nxv4p0f16(<vscale x 4 x half> %val, <vscale x 4 x half*> %ptrs, i32 2, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8f16.nxv8p0f16(<vscale x 8 x half>, <vscale x 8 x half*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8f16(<vscale x 8 x half> %val, <vscale x 8 x half*> %ptrs, <vscale x 8 x i1> %m) {
@@ -1238,6 +1298,18 @@ define void @mscatter_truemask_nxv4f32(<vscale x 4 x float> %val, <vscale x 4 x
   ret void
 }
 
+define void @mscatter_falsemask_nxv4f32(<vscale x 4 x float> %val, <vscale x 4 x float*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4f32:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4f32:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4f32.nxv4p0f32(<vscale x 4 x float> %val, <vscale x 4 x float*> %ptrs, i32 4, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8f32.nxv8p0f32(<vscale x 8 x float>, <vscale x 8 x float*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8f32(<vscale x 8 x float> %val, <vscale x 8 x float*> %ptrs, <vscale x 8 x i1> %m) {
@@ -1485,6 +1557,18 @@ define void @mscatter_truemask_nxv4f64(<vscale x 4 x double> %val, <vscale x 4 x
   ret void
 }
 
+define void @mscatter_falsemask_nxv4f64(<vscale x 4 x double> %val, <vscale x 4 x double*> %ptrs) {
+; RV32-LABEL: mscatter_falsemask_nxv4f64:
+; RV32:       # %bb.0:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: mscatter_falsemask_nxv4f64:
+; RV64:       # %bb.0:
+; RV64-NEXT:    ret
+  call void @llvm.masked.scatter.nxv4f64.nxv4p0f64(<vscale x 4 x double> %val, <vscale x 4 x double*> %ptrs, i32 8, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv8f64.nxv8p0f64(<vscale x 8 x double>, <vscale x 8 x double*>, i32, <vscale x 8 x i1>)
 
 define void @mscatter_nxv8f64(<vscale x 8 x double> %val, <vscale x 8 x double*> %ptrs, <vscale x 8 x i1> %m) {

