[llvm] AMDGPU: Generalize truncate of shift of cast build_vector combine (PR #125617)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 3 17:40:47 PST 2025
https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/125617
Previously we only handled cases that looked like the high-element
extract of a 64-bit shift. Generalize this to handle shifts by any
multiple of the source element size. I was hoping this would help
avoid some regressions, but it did not. It does, however, reduce the
number of steps the DAG combiner takes to process these cases.
NFC-ish: I have yet to find an example where this changes the
final output.
>From 8feca1b6c1af1308e443c9ca00db918b09bcb43f Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 27 Jan 2025 13:36:37 +0700
Subject: [PATCH] AMDGPU: Generalize truncate of shift of cast build_vector
combine
Previously we only handled cases that looked like the high-element
extract of a 64-bit shift. Generalize this to handle shifts by any
multiple of the source element size. I was hoping this would help
avoid some regressions, but it did not. It does, however, reduce the
number of steps the DAG combiner takes to process these cases.
NFC-ish: I have yet to find an example where this changes the
final output.
---
llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 25 ++--
...truncate-lshr-cast-build-vector-combine.ll | 140 ++++++++++++++++++
2 files changed, 154 insertions(+), 11 deletions(-)
create mode 100644 llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cca9fa72d0ca534..792e17eeedab141 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4217,18 +4217,21 @@ SDValue AMDGPUTargetLowering::performTruncateCombine(
// trunc (srl (bitcast (build_vector x, y))), 16 -> trunc (bitcast y)
if (Src.getOpcode() == ISD::SRL && !VT.isVector()) {
if (auto *K = isConstOrConstSplat(Src.getOperand(1))) {
- if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
- SDValue BV = stripBitcast(Src.getOperand(0));
- if (BV.getOpcode() == ISD::BUILD_VECTOR &&
- BV.getValueType().getVectorNumElements() == 2) {
- SDValue SrcElt = BV.getOperand(1);
- EVT SrcEltVT = SrcElt.getValueType();
- if (SrcEltVT.isFloatingPoint()) {
- SrcElt = DAG.getNode(ISD::BITCAST, SL,
- SrcEltVT.changeTypeToInteger(), SrcElt);
+ SDValue BV = stripBitcast(Src.getOperand(0));
+ if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+ EVT SrcEltVT = BV.getOperand(0).getValueType();
+ unsigned SrcEltSize = SrcEltVT.getSizeInBits();
+ unsigned BitIndex = K->getZExtValue();
+ unsigned PartIndex = BitIndex / SrcEltSize;
+
+ if (PartIndex * SrcEltSize == BitIndex &&
+ PartIndex < BV.getNumOperands()) {
+ if (SrcEltVT.getSizeInBits() == VT.getSizeInBits()) {
+ SDValue SrcElt =
+ DAG.getNode(ISD::BITCAST, SL, SrcEltVT.changeTypeToInteger(),
+ BV.getOperand(PartIndex));
+ return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
}
-
- return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
}
}
}
diff --git a/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
new file mode 100644
index 000000000000000..1c3091f6b8d3bf1
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+; extract element 0 with a plain trunc (no shift needed)
+define i32 @cast_v4i32_to_i128_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %trunc = trunc i128 %bigint to i32
+ ret i32 %trunc
+}
+
+; extract element 1 as lshr by 32 (one 32-bit element)
+define i32 @cast_v4i32_to_i128_lshr_32_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v1
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 32
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; extract element 2 as lshr by 64 (two 32-bit elements)
+define i32 @cast_v4i32_to_i128_lshr_64_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 64
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; extract element 3 as lshr by 96 (three 32-bit elements)
+define i32 @cast_v4i32_to_i128_lshr_96_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_96_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v3
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 96
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; lshr by 33 is not a multiple of 32, so this is not a simple element extract
+define i32 @cast_v4i32_to_i128_lshr_33_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_33_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_alignbit_b32 v0, v2, v1, 1
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 33
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; lshr by 31 straddles elements 0 and 1, so this is not a simple element extract
+define i32 @cast_v4i32_to_i128_lshr_31_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_31_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_alignbit_b32 v0, v1, v0, 31
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 31
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; lshr by 48 straddles elements 1 and 2, so this is not a simple element extract
+define i32 @cast_v4i32_to_i128_lshr_48_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_48_trunc_i32:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_mov_b32 s4, 0x1000706
+; CHECK-NEXT: v_perm_b32 v0, v1, v2, s4
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 48
+ %trunc = trunc i128 %srl to i32
+ ret i32 %trunc
+}
+
+; lshr by 32 then trunc to i64 extracts elements 1 and 2
+define i64 @cast_v4i32_to_i128_lshr_32_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v0, v1
+; CHECK-NEXT: v_mov_b32_e32 v1, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 32
+ %trunc = trunc i128 %srl to i64
+ ret i64 %trunc
+}
+
+; lshr by 64 then trunc to i64 extracts elements 2 and 3
+define i64 @cast_v4i32_to_i128_lshr_64_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i64:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mov_b32_e32 v1, v3
+; CHECK-NEXT: v_mov_b32_e32 v0, v2
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %bigint = bitcast <4 x i32> %arg to i128
+ %srl = lshr i128 %bigint, 64
+ %trunc = trunc i128 %srl to i64
+ ret i64 %trunc
+}
+
+; FIXME: The truncate combine does not process this case because multiple
+; bitcasts are seen before reaching a 32-bit build_vector.
+define i32 @build_vector_i16_to_shift(i16 %arg0, i16 %arg1, i16 %arg2, i16 %arg3) {
+; CHECK-LABEL: build_vector_i16_to_shift:
+; CHECK: ; %bb.0:
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: s_mov_b32 s4, 0x5040100
+; CHECK-NEXT: v_perm_b32 v0, v3, v2, s4
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+ %ins.0 = insertelement <4 x i16> poison, i16 %arg0, i32 0
+ %ins.1 = insertelement <4 x i16> %ins.0, i16 %arg1, i32 1
+ %ins.2 = insertelement <4 x i16> %ins.1, i16 %arg2, i32 2
+ %ins.3 = insertelement <4 x i16> %ins.2, i16 %arg3, i32 3
+
+ %cast = bitcast <4 x i16> %ins.3 to i64
+ %srl = lshr i64 %cast, 32
+ %trunc = trunc i64 %srl to i32
+ ret i32 %trunc
+}
More information about the llvm-commits mailing list