[llvm] 46525fe - [DAGCombine] Check both forms of a commutative transform

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 22 12:22:31 PDT 2022


Author: Philip Reames
Date: 2022-09-22T12:21:47-07:00
New Revision: 46525fee812343b5e603a77893dfe2983768ca56

URL: https://github.com/llvm/llvm-project/commit/46525fee812343b5e603a77893dfe2983768ca56
DIFF: https://github.com/llvm/llvm-project/commit/46525fee812343b5e603a77893dfe2983768ca56.diff

LOG: [DAGCombine] Check both forms of a commutative transform

The transform that folds an add into the base of a scatter/gather was only checking whether the LHS of the add was a splat. The included test change shows that splats are not canonicalized to the LHS, so we need to check both operands.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0eb1265dcc617..2dad2f25c8318 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10674,15 +10674,19 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
   if (IndexIsScaled)
     return false;
 
-  // For now we check only the LHS of the add.
-  SDValue LHS = Index.getOperand(0);
-  SDValue SplatVal = DAG.getSplatValue(LHS);
-  if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
-    return false;
-
-  BasePtr = SplatVal;
-  Index = Index.getOperand(1);
-  return true;
+  if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
+      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
+    BasePtr = SplatVal;
+    Index = Index.getOperand(1);
+    return true;
+  }
+  if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
+      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
+    BasePtr = SplatVal;
+    Index = Index.getOperand(0);
+    return true;
+  }
+  return false;
 }
 
 // Fold sext/zext of index into index type.

diff  --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
index 9c90f8acaed80..cc7b2815897fa 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -14,10 +14,9 @@ define void @complex_gep(ptr %p, <vscale x 2 x i64> %vec.ind, <vscale x 2 x i1>
 ; RV32-NEXT:    vnsrl.wi v11, v8, 0
 ; RV32-NEXT:    li a0, 48
 ; RV32-NEXT:    vmadd.vx v11, a0, v10
+; RV32-NEXT:    vmv.v.i v8, 0
 ; RV32-NEXT:    li a0, 28
-; RV32-NEXT:    vadd.vx v8, v11, a0
-; RV32-NEXT:    vmv.v.i v9, 0
-; RV32-NEXT:    vsoxei32.v v9, (zero), v8, v0.t
+; RV32-NEXT:    vsoxei32.v v8, (a0), v11, v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: complex_gep:
@@ -26,11 +25,10 @@ define void @complex_gep(ptr %p, <vscale x 2 x i64> %vec.ind, <vscale x 2 x i1>
 ; RV64-NEXT:    vmv.v.x v10, a0
 ; RV64-NEXT:    li a0, 56
 ; RV64-NEXT:    vmacc.vx v10, a0, v8
-; RV64-NEXT:    li a0, 32
-; RV64-NEXT:    vadd.vx v8, v10, a0
 ; RV64-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
-; RV64-NEXT:    vmv.v.i v10, 0
-; RV64-NEXT:    vsoxei64.v v10, (zero), v8, v0.t
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    li a0, 32
+; RV64-NEXT:    vsoxei64.v v8, (a0), v10, v0.t
 ; RV64-NEXT:    ret
   %gep = getelementptr inbounds %struct, ptr %p, <vscale x 2 x i64> %vec.ind, i32 5
   call void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32> zeroinitializer, <vscale x 2 x ptr> %gep, i32 8, <vscale x 2 x i1> %m)


        


More information about the llvm-commits mailing list