[llvm] 077e0c1 - AMDGPU: Generalize truncate of shift of cast build_vector combine (#125617)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 3 20:46:34 PST 2025


Author: Matt Arsenault
Date: 2025-02-04T11:46:30+07:00
New Revision: 077e0c134a31cc16c432ce685458b1de80bfbf84

URL: https://github.com/llvm/llvm-project/commit/077e0c134a31cc16c432ce685458b1de80bfbf84
DIFF: https://github.com/llvm/llvm-project/commit/077e0c134a31cc16c432ce685458b1de80bfbf84.diff

LOG: AMDGPU: Generalize truncate of shift of cast build_vector combine (#125617)

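Previously we only handled cases that looked like extracting the high
element of a two-element build_vector, i.e. a right shift by exactly half
the source width (such as the high half of a 64-bit value). Generalize
this to handle a shift by any multiple of the element size, as long as
the truncated result is exactly one element wide. I was hoping this would
help avoid some regressions, but it did not. It does, however, reduce the
number of steps the DAG combiner takes to process these cases.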

NFC-ish: I have yet to find an example where this changes the
final output.

Added: 
    llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cca9fa72d0ca53..792e17eeedab14 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4217,18 +4217,21 @@ SDValue AMDGPUTargetLowering::performTruncateCombine(
   // trunc (srl (bitcast (build_vector x, y))), 16 -> trunc (bitcast y)
   if (Src.getOpcode() == ISD::SRL && !VT.isVector()) {
     if (auto *K = isConstOrConstSplat(Src.getOperand(1))) {
-      if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
-        SDValue BV = stripBitcast(Src.getOperand(0));
-        if (BV.getOpcode() == ISD::BUILD_VECTOR &&
-            BV.getValueType().getVectorNumElements() == 2) {
-          SDValue SrcElt = BV.getOperand(1);
-          EVT SrcEltVT = SrcElt.getValueType();
-          if (SrcEltVT.isFloatingPoint()) {
-            SrcElt = DAG.getNode(ISD::BITCAST, SL,
-                                 SrcEltVT.changeTypeToInteger(), SrcElt);
+      SDValue BV = stripBitcast(Src.getOperand(0));
+      if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+        EVT SrcEltVT = BV.getOperand(0).getValueType();
+        unsigned SrcEltSize = SrcEltVT.getSizeInBits();
+        unsigned BitIndex = K->getZExtValue();
+        unsigned PartIndex = BitIndex / SrcEltSize;
+
+        if (PartIndex * SrcEltSize == BitIndex &&
+            PartIndex < BV.getNumOperands()) {
+          if (SrcEltVT.getSizeInBits() == VT.getSizeInBits()) {
+            SDValue SrcElt =
+                DAG.getNode(ISD::BITCAST, SL, SrcEltVT.changeTypeToInteger(),
+                            BV.getOperand(PartIndex));
+            return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
           }
-
-          return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
         }
       }
     }

diff  --git a/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
new file mode 100644
index 00000000000000..1c3091f6b8d3bf
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+; extract element 0 (no shift; a plain truncate of the low 32 bits)
+define i32 @cast_v4i32_to_i128_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %trunc = trunc i128 %bigint to i32
+  ret i32 %trunc
+}
+
+; extract element 1 as a shift by 32
+define i32 @cast_v4i32_to_i128_lshr_32_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 2 as a shift by 64
+define i32 @cast_v4i32_to_i128_lshr_64_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 3 as a shift by 96
+define i32 @cast_v4i32_to_i128_lshr_96_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_96_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v3
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 96
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; shift by 33 is not a multiple of the element size; not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_33_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_33_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v2, v1, 1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 33
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; shift by 31 is not a multiple of the element size; not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_31_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_31_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v1, v0, 31
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 31
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; shift by 48 is not a multiple of the element size; not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_48_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_48_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x1000706
+; CHECK-NEXT:    v_perm_b32 v0, v1, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 48
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract elements 1 and 2 with a shift by 32 (i64 result spans two elements)
+define i64 @cast_v4i32_to_i128_lshr_32_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    v_mov_b32_e32 v1, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; extract elements 2 and 3 with a shift by 64 (i64 result spans two elements)
+define i64 @cast_v4i32_to_i128_lshr_64_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v1, v3
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; FIXME: This case is not handled because multiple bitcasts are seen
+; before reaching the 32-bit build_vector.
+define i32 @build_vector_i16_to_shift(i16 %arg0, i16 %arg1, i16 %arg2, i16 %arg3) {
+; CHECK-LABEL: build_vector_i16_to_shift:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x5040100
+; CHECK-NEXT:    v_perm_b32 v0, v3, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %ins.0 = insertelement <4 x i16> poison, i16 %arg0, i32 0
+  %ins.1 = insertelement <4 x i16> %ins.0, i16 %arg1, i32 1
+  %ins.2 = insertelement <4 x i16> %ins.1, i16 %arg2, i32 2
+  %ins.3 = insertelement <4 x i16> %ins.2, i16 %arg3, i32 3
+
+  %cast = bitcast <4 x i16> %ins.3 to i64
+  %srl = lshr i64 %cast, 32
+  %trunc = trunc i64 %srl to i32
+  ret i32 %trunc
+}


        


More information about the llvm-commits mailing list