[llvm] 077e0c1 - AMDGPU: Generalize truncate of shift of cast build_vector combine (#125617)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 3 20:46:34 PST 2025
Author: Matt Arsenault
Date: 2025-02-04T11:46:30+07:00
New Revision: 077e0c134a31cc16c432ce685458b1de80bfbf84
URL: https://github.com/llvm/llvm-project/commit/077e0c134a31cc16c432ce685458b1de80bfbf84
DIFF: https://github.com/llvm/llvm-project/commit/077e0c134a31cc16c432ce685458b1de80bfbf84.diff
LOG: AMDGPU: Generalize truncate of shift of cast build_vector combine (#125617)
Previously we only handled cases that looked like the high element
extract of a 64-bit shift. Generalize this to handle any shift amount
that is an exact multiple of the build_vector element size. I was hoping this would help avoid some regressions,
but it did not. It does, however, reduce the number of steps the DAG
takes to process these cases.
NFC-ish; I have yet to find an example where this changes the
final output.
Added:
llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
Modified:
llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cca9fa72d0ca53..792e17eeedab14 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4217,18 +4217,21 @@ SDValue AMDGPUTargetLowering::performTruncateCombine(
// trunc (srl (bitcast (build_vector x, y))), 16 -> trunc (bitcast y)
if (Src.getOpcode() == ISD::SRL && !VT.isVector()) {
if (auto *K = isConstOrConstSplat(Src.getOperand(1))) {
- if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
- SDValue BV = stripBitcast(Src.getOperand(0));
- if (BV.getOpcode() == ISD::BUILD_VECTOR &&
- BV.getValueType().getVectorNumElements() == 2) {
- SDValue SrcElt = BV.getOperand(1);
- EVT SrcEltVT = SrcElt.getValueType();
- if (SrcEltVT.isFloatingPoint()) {
- SrcElt = DAG.getNode(ISD::BITCAST, SL,
- SrcEltVT.changeTypeToInteger(), SrcElt);
+ SDValue BV = stripBitcast(Src.getOperand(0));
+ if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+ EVT SrcEltVT = BV.getOperand(0).getValueType();
+ unsigned SrcEltSize = SrcEltVT.getSizeInBits();
+ unsigned BitIndex = K->getZExtValue();
+ unsigned PartIndex = BitIndex / SrcEltSize;
+
+ if (PartIndex * SrcEltSize == BitIndex &&
+ PartIndex < BV.getNumOperands()) {
+ if (SrcEltVT.getSizeInBits() == VT.getSizeInBits()) {
+ SDValue SrcElt =
+ DAG.getNode(ISD::BITCAST, SL, SrcEltVT.changeTypeToInteger(),
+ BV.getOperand(PartIndex));
+ return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
}
-
- return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
}
}
}
diff --git a/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
new file mode 100644
index 00000000000000..1c3091f6b8d3bf
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+; extract element 0 with a plain truncate (no shift)
+define i32 @cast_v4i32_to_i128_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %trunc = trunc i128 %bigint to i32
+ ret i32 %trunc
+}
+
+; extract element 1 as shift
+define i32 @cast_v4i32_to_i128_lshr_32_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v1
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 32
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; extract element 2 as shift
+define i32 @cast_v4i32_to_i128_lshr_64_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 64
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; extract element 3 as shift
+define i32 @cast_v4i32_to_i128_lshr_96_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_96_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v3
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 96
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; Shift not aligned to element, not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_33_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_33_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_alignbit_b32 v0, v2, v1, 1
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 33
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; Shift of 31 is not a multiple of the element size; result straddles elements 0 and 1
+define i32 @cast_v4i32_to_i128_lshr_31_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_31_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_alignbit_b32 v0, v1, v0, 31
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 31
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; Shift of 48 is not a multiple of the element size; result straddles elements 1 and 2
+define i32 @cast_v4i32_to_i128_lshr_48_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_48_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_mov_b32 s4, 0x1000706
+; CHECK-NEXT: v_perm_b32 v0, v1, v2, s4
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 48
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; extract elements 1 and 2 with shift
+define i64 @cast_v4i32_to_i128_lshr_32_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v1
+; CHECK-NEXT: v_mov_b32_e32 v1, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 32
+ %trunc = trunc i128 %srl to i64
+ ret i64 %trunc
+}
+
+; extract elements 2 and 3 with shift
+define i64 @cast_v4i32_to_i128_lshr_64_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v1, v3
+; CHECK-NEXT: v_mov_b32_e32 v0, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 64
+ %trunc = trunc i128 %srl to i64
+ ret i64 %trunc
+}
+
+; FIXME: We don't process this case because we see multiple bitcasts
+; before a 32-bit build_vector
+define i32 @build_vector_i16_to_shift(i16 %arg0, i16 %arg1, i16 %arg2, i16 %arg3) {
+; CHECK-LABEL: build_vector_i16_to_shift:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_mov_b32 s4, 0x5040100
+; CHECK-NEXT: v_perm_b32 v0, v3, v2, s4
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %ins.0 = insertelement <4 x i16> poison, i16 %arg0, i32 0
+ %ins.1 = insertelement <4 x i16> %ins.0, i16 %arg1, i32 1
+ %ins.2 = insertelement <4 x i16> %ins.1, i16 %arg2, i32 2
+ %ins.3 = insertelement <4 x i16> %ins.2, i16 %arg3, i32 3
+
+ %cast = bitcast <4 x i16> %ins.3 to i64
+ %srl = lshr i64 %cast, 32
+ %trunc = trunc i64 %srl to i32
+ ret i32 %trunc
+}
More information about the llvm-commits
mailing list