[llvm] 7a0b897 - [DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index scaling
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 29 04:43:38 PDT 2022
Author: Paul Walker
Date: 2022-04-29T12:35:16+01:00
New Revision: 7a0b897e8664d11481230a69a88fca2b2ee5f904
URL: https://github.com/llvm/llvm-project/commit/7a0b897e8664d11481230a69a88fca2b2ee5f904
DIFF: https://github.com/llvm/llvm-project/commit/7a0b897e8664d11481230a69a88fca2b2ee5f904.diff
LOG: [DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index scaling
refineUniformBase and selectGatherScatterAddrMode both attempt the
transformation:
base(0) + index(A+splat(B)) => base(B) + index(A)
However, this is only safe when the index is not implicitly scaled.
Differential Revision: https://reviews.llvm.org/D123222
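
For intuition, here is a minimal C++ sketch (editorial, not part of the patch) of the per-lane address computation. With an implicit scale, the splatted operand B is multiplied by the scale before the rewrite but not after, so the two forms agree only for unscaled indices:

#include <cstdint>

// Per-lane address before the rewrite: base(0) + index(A + B), with the
// index implicitly multiplied by Scale (8 for an ld1d gather of i64).
uint64_t addrBefore(uint64_t A, uint64_t B, uint64_t Scale) {
  return 0 + (A + B) * Scale; // == A*Scale + B*Scale
}

// Per-lane address after the rewrite: base(B) + index(A). B has escaped
// the implicit scaling.
uint64_t addrAfter(uint64_t A, uint64_t B, uint64_t Scale) {
  return B + A * Scale; // short of addrBefore by (Scale - 1) * B
}

// The two agree for all A and B only when Scale == 1, which is exactly
// the condition the patch now checks before rewriting.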
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 71618eb2bd7c..181ff00184b3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10426,14 +10426,19 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
}
-bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
+bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
+ SelectionDAG &DAG) {
if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
return false;
+ // Only perform the transformation when existing operands can be reused.
+ if (IndexIsScaled)
+ return false;
+
// For now we check only the LHS of the add.
SDValue LHS = Index.getOperand(0);
SDValue SplatVal = DAG.getSplatValue(LHS);
- if (!SplatVal)
+ if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
return false;
BasePtr = SplatVal;
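
Besides the IndexIsScaled bail-out, the hunk above adds a second guard: the value returned by DAG.getSplatValue can only stand in for BasePtr directly when its type matches BasePtr's; otherwise the substitution would need an extend or truncate. A hedged restatement of that condition (illustrative helper, not the real SDValue API):

// The rewrite installs the splatted value as the new base, so it must
// already have the base pointer's type to be a drop-in replacement.
template <typename Ty>
bool splatReusableAsBase(Ty SplatTy, Ty BasePtrTy) {
  return SplatTy == BasePtrTy;
}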
@@ -10481,7 +10486,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return Chain;
- if (refineUniformBase(BasePtr, Index, DAG)) {
+ if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10576,7 +10581,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return CombineTo(N, PassThru, MGT->getChain());
- if (refineUniformBase(BasePtr, Index, DAG)) {
+ if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
MGT->getMemoryVT(), DL, Ops,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a5dadd112c9e..3f6a36fe49f4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4656,10 +4656,10 @@ bool getGatherScatterIndexIsExtended(SDValue Index) {
// VECTOR + IMMEDIATE:
// getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
// -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
- unsigned &Opcode, bool IsGather,
- SelectionDAG &DAG) {
- if (!isNullConstant(BasePtr))
+void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
+ bool IsScaled, EVT MemVT, unsigned &Opcode,
+ bool IsGather, SelectionDAG &DAG) {
+ if (!isNullConstant(BasePtr) || IsScaled)
return;
// FIXME: This will not match for fixed vector type codegen as the nodes in
@@ -4789,7 +4789,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
Index = Index.getOperand(0);
unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
- selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+ selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
/*isGather=*/true, DAG);
if (ExtType == ISD::SEXTLOAD)
@@ -4898,7 +4898,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
Index = Index.getOperand(0);
unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
- selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+ selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
/*isGather=*/false, DAG);
if (IsFixedLength) {
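
selectGatherScatterAddrMode performs the analogous fold during AArch64 lowering, so it gets the same treatment: with IsScaled set, a splat buried in the index is implicitly shifted by the element size (the lsl #3 in the ld1d/st1d forms in the tests below), and hoisting it into the base register would drop that shift. A compact model of the new decision (hypothetical helper name):

// For a scaled form, addr(lane) = base + (index[lane] << log2(EltSize)),
// so splat(B) inside the index contributes B << log2(EltSize), while as a
// base it would contribute only B. Restrict the fold to unscaled forms:
bool mayHoistSplatIntoBase(bool BaseIsNull, bool IndexIsScaled) {
  return BaseIsNull && !IndexIsScaled;
}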
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index d06cc313ba53..4fdf4a106dbc 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -343,12 +343,13 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_const_with_vec_offsets(<vscale
ret <vscale x 2 x i64> %data
}
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
; CHECK: // %bb.0:
-; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0, z0.d, lsl #3]
+; CHECK-NEXT: mov x8, xzr
+; CHECK-NEXT: mov z1.d, x0
+; CHECK-NEXT: add z0.d, z0.d, z1.d
+; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
; CHECK-NEXT: ret
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
%scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -358,12 +359,11 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offse
ret <vscale x 2 x i64> %data
}
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
; CHECK: // %bb.0:
-; CHECK-NEXT: mov w8, #1
+; CHECK-NEXT: mov x8, xzr
+; CHECK-NEXT: add z0.d, z0.d, #1 // =0x1
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
; CHECK-NEXT: ret
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -425,12 +425,13 @@ define void @masked_scatter_nxv2i64_const_with_vec_offsets(<vscale x 2 x i64> %v
ret void
}
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
; CHECK: // %bb.0:
-; CHECK-NEXT: st1d { z1.d }, p0, [x0, z0.d, lsl #3]
+; CHECK-NEXT: mov x8, xzr
+; CHECK-NEXT: mov z2.d, x0
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: st1d { z1.d }, p0, [x8, z0.d, lsl #3]
; CHECK-NEXT: ret
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
%scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -440,12 +441,11 @@ define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x
ret void
}
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
; CHECK: // %bb.0:
-; CHECK-NEXT: mov w8, #1
+; CHECK-NEXT: mov x8, xzr
+; CHECK-NEXT: add z0.d, z0.d, #1 // =0x1
; CHECK-NEXT: st1d { z1.d }, p0, [x8, z0.d, lsl #3]
; CHECK-NEXT: ret
%scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
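
To see the fix in the tests, here is a hedged scalar model (editorial, not from the patch) of masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets. Each active lane should load from byte address (vector_offsets[i] + scalar_offset) * 8; the pre-patch output applied the lsl #3 to the vector term only:

#include <cstdint>

// New codegen: mov x8, xzr; add z0.d, z0.d, splat(x0);
// ld1d [x8, z0.d, lsl #3] -- both terms scaled.
uint64_t laneAddrFixed(uint64_t vec_off, uint64_t scalar_off) {
  return (vec_off + scalar_off) << 3;
}

// Old codegen: ld1d [x0, z0.d, lsl #3] -- scalar_off used as an unscaled
// byte base, losing its lsl #3.
uint64_t laneAddrOldBug(uint64_t vec_off, uint64_t scalar_off) {
  return scalar_off + (vec_off << 3);
}

// The results differ by 7 * scalar_off whenever scalar_off != 0, which is
// why the TODO comments on these tests could be removed once the combine
// was fixed.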