[llvm] 2f005df - [DAG][X86] Fold mgather/mscatter/etc with splat index (#65980)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 13 09:26:33 PDT 2023


Author: Philip Reames
Date: 2023-09-13T09:26:30-07:00
New Revision: 2f005df066e07d93e3d6aa04748c158f883197b7

URL: https://github.com/llvm/llvm-project/commit/2f005df066e07d93e3d6aa04748c158f883197b7
DIFF: https://github.com/llvm/llvm-project/commit/2f005df066e07d93e3d6aa04748c158f883197b7.diff

LOG: [DAG][X86] Fold mgather/mscatter/etc with splat index (#65980)

A splat index means the operation reads from (or writes to) the same
memory location in every lane. Generally, zero is the cheapest value to
splat, so we prefer to add the splatted value to the base pointer and use
a constant zero as the index operand.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/masked_gather.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f574c746142f959..acac1fd1b2a515f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11637,8 +11637,6 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
 
 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
                        SelectionDAG &DAG, const SDLoc &DL) {
-  if (Index.getOpcode() != ISD::ADD)
-    return false;
 
   // Only perform the transformation when existing operands can be reused.
   if (IndexIsScaled)
@@ -11648,6 +11646,21 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
     return false;
 
   EVT VT = BasePtr.getValueType();
+
+  if (SDValue SplatVal = DAG.getSplatValue(Index);
+      SplatVal && !isNullConstant(SplatVal) &&
+      SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
+    Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
+    return true;
+  }
+
+  if (Index.getOpcode() != ISD::ADD)
+    return false;
+
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
       SplatVal && SplatVal.getValueType() == VT) {
     if (isNullConstant(BasePtr))

diff  --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll
index 52d2187b50d6109..559a7ec0930b994 100644
--- a/llvm/test/CodeGen/X86/masked_gather.ll
+++ b/llvm/test/CodeGen/X86/masked_gather.ll
@@ -1747,29 +1747,27 @@ define <8 x i32> @gather_v8i32_v8i32(<8 x i32> %trigger) {
 ; AVX512F-NEXT:    vptestnmd %zmm0, %zmm0, %k0
 ; AVX512F-NEXT:    kshiftlw $8, %k0, %k0
 ; AVX512F-NEXT:    kshiftrw $8, %k0, %k1
-; AVX512F-NEXT:    vpbroadcastd {{.*#+}} zmm0 = [12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12]
+; AVX512F-NEXT:    vpxor %xmm0, %xmm0, %xmm0
 ; AVX512F-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX512F-NEXT:    vpxor %xmm2, %xmm2, %xmm2
 ; AVX512F-NEXT:    kmovw %k1, %k2
-; AVX512F-NEXT:    vpgatherdd c(,%zmm0), %zmm2 {%k2}
-; AVX512F-NEXT:    vpbroadcastd {{.*#+}} zmm0 = [28,28,28,28,28,28,28,28,28,28,28,28,28,28,28,28]
-; AVX512F-NEXT:    vpgatherdd c(,%zmm0), %zmm1 {%k1}
-; AVX512F-NEXT:    vpaddd %ymm1, %ymm2, %ymm0
-; AVX512F-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX512F-NEXT:    vpgatherdd c+12(,%zmm0), %zmm1 {%k2}
+; AVX512F-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512F-NEXT:    vpgatherdd c+28(,%zmm0), %zmm2 {%k1}
+; AVX512F-NEXT:    vpaddd %ymm2, %ymm1, %ymm0
+; AVX512F-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
 ; AVX512F-NEXT:    retq
 ;
 ; AVX512VL-LABEL: gather_v8i32_v8i32:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vptestnmd %ymm0, %ymm0, %k1
 ; AVX512VL-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [12,12,12,12,12,12,12,12]
 ; AVX512VL-NEXT:    kmovw %k1, %k2
+; AVX512VL-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VL-NEXT:    vpgatherdd c+12(,%ymm0), %ymm1 {%k2}
 ; AVX512VL-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX512VL-NEXT:    vpgatherdd c(,%ymm1), %ymm2 {%k2}
-; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [28,28,28,28,28,28,28,28]
-; AVX512VL-NEXT:    vpgatherdd c(,%ymm1), %ymm0 {%k1}
-; AVX512VL-NEXT:    vpaddd %ymm0, %ymm2, %ymm1
-; AVX512VL-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX512VL-NEXT:    vpgatherdd c+28(,%ymm0), %ymm2 {%k1}
+; AVX512VL-NEXT:    vpaddd %ymm2, %ymm1, %ymm0
+; AVX512VL-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
 ; AVX512VL-NEXT:    retq
   %1 = icmp eq <8 x i32> %trigger, zeroinitializer
   %2 = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> getelementptr (%struct.a, <8 x ptr> <ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c>, <8 x i64> zeroinitializer, i32 0, <8 x i64> <i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3>), i32 4, <8 x i1> %1, <8 x i32> undef)


        


More information about the llvm-commits mailing list