[llvm] [DAG][X86] Fold mgather/mscatter/etc with splat index (PR #65980)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 11 09:54:48 PDT 2023


https://github.com/preames created https://github.com/llvm/llvm-project/pull/65980:

A splat index means every lane of the operation reads from (or writes to) the same memory location.  Generally, zero is the cheapest value to splat.  As such, we'd prefer to add the splatted value to the base and use a constant zero as the index operand.

>From 14aae175adcaff5d5a0f23f39056fd6c067cbd90 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 11 Sep 2023 09:45:45 -0700
Subject: [PATCH] [DAG][X86] Fold mgather/mscatter/etc with splat index

A splat index means every lane of the operation reads from (or writes to) the same memory location.  Generally, zero is the cheapest value to splat.  As such, we'd prefer to add the splatted value to the base and use a constant zero as the index operand.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 17 +++++++++++--
 llvm/test/CodeGen/X86/masked_gather.ll        | 24 +++++++++----------
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d917e8c00c4f92a..421d3f076afaf4c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11637,8 +11637,6 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
 
 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
                        SelectionDAG &DAG, const SDLoc &DL) {
-  if (Index.getOpcode() != ISD::ADD)
-    return false;
 
   // Only perform the transformation when existing operands can be reused.
   if (IndexIsScaled)
@@ -11648,6 +11646,21 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
     return false;
 
   EVT VT = BasePtr.getValueType();
+
+  if (SDValue SplatVal = DAG.getSplatValue(Index);
+      SplatVal && !isNullConstant(SplatVal) &&
+      SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
+    Index = DAG.getSplat(Index.getValueType(), DL, DAG.getConstant(0, DL, VT));
+    return true;
+  }
+
+  if (Index.getOpcode() != ISD::ADD)
+    return false;
+
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
       SplatVal && SplatVal.getValueType() == VT) {
     if (isNullConstant(BasePtr))
diff --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll
index 52d2187b50d6109..559a7ec0930b994 100644
--- a/llvm/test/CodeGen/X86/masked_gather.ll
+++ b/llvm/test/CodeGen/X86/masked_gather.ll
@@ -1747,29 +1747,27 @@ define <8 x i32> @gather_v8i32_v8i32(<8 x i32> %trigger) {
 ; AVX512F-NEXT:    vptestnmd %zmm0, %zmm0, %k0
 ; AVX512F-NEXT:    kshiftlw $8, %k0, %k0
 ; AVX512F-NEXT:    kshiftrw $8, %k0, %k1
-; AVX512F-NEXT:    vpbroadcastd {{.*#+}} zmm0 = [12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12]
+; AVX512F-NEXT:    vpxor %xmm0, %xmm0, %xmm0
 ; AVX512F-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX512F-NEXT:    vpxor %xmm2, %xmm2, %xmm2
 ; AVX512F-NEXT:    kmovw %k1, %k2
-; AVX512F-NEXT:    vpgatherdd c(,%zmm0), %zmm2 {%k2}
-; AVX512F-NEXT:    vpbroadcastd {{.*#+}} zmm0 = [28,28,28,28,28,28,28,28,28,28,28,28,28,28,28,28]
-; AVX512F-NEXT:    vpgatherdd c(,%zmm0), %zmm1 {%k1}
-; AVX512F-NEXT:    vpaddd %ymm1, %ymm2, %ymm0
-; AVX512F-NEXT:    vpaddd %ymm1, %ymm0, %ymm0
+; AVX512F-NEXT:    vpgatherdd c+12(,%zmm0), %zmm1 {%k2}
+; AVX512F-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512F-NEXT:    vpgatherdd c+28(,%zmm0), %zmm2 {%k1}
+; AVX512F-NEXT:    vpaddd %ymm2, %ymm1, %ymm0
+; AVX512F-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
 ; AVX512F-NEXT:    retq
 ;
 ; AVX512VL-LABEL: gather_v8i32_v8i32:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vptestnmd %ymm0, %ymm0, %k1
 ; AVX512VL-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [12,12,12,12,12,12,12,12]
 ; AVX512VL-NEXT:    kmovw %k1, %k2
+; AVX512VL-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX512VL-NEXT:    vpgatherdd c+12(,%ymm0), %ymm1 {%k2}
 ; AVX512VL-NEXT:    vpxor %xmm2, %xmm2, %xmm2
-; AVX512VL-NEXT:    vpgatherdd c(,%ymm1), %ymm2 {%k2}
-; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [28,28,28,28,28,28,28,28]
-; AVX512VL-NEXT:    vpgatherdd c(,%ymm1), %ymm0 {%k1}
-; AVX512VL-NEXT:    vpaddd %ymm0, %ymm2, %ymm1
-; AVX512VL-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX512VL-NEXT:    vpgatherdd c+28(,%ymm0), %ymm2 {%k1}
+; AVX512VL-NEXT:    vpaddd %ymm2, %ymm1, %ymm0
+; AVX512VL-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
 ; AVX512VL-NEXT:    retq
   %1 = icmp eq <8 x i32> %trigger, zeroinitializer
   %2 = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> getelementptr (%struct.a, <8 x ptr> <ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c, ptr @c>, <8 x i64> zeroinitializer, i32 0, <8 x i64> <i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3, i64 3>), i32 4, <8 x i1> %1, <8 x i32> undef)



More information about the llvm-commits mailing list