[llvm] [X86][AVX] Fix handling of out-of-bounds shift amounts in AVX2 vector shift nodes (PR #84426)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 09:04:26 PDT 2024


https://github.com/SahilPatidar updated https://github.com/llvm/llvm-project/pull/84426

From f48f6a56bd88dbc5c105c62be99b980ee2b416c1 Mon Sep 17 00:00:00 2001
From: SahilPatidar <patidarsahil2001 at gmail.com>
Date: Thu, 7 Mar 2024 21:26:31 +0530
Subject: [PATCH] [X86][AVX] Fix handling of out-of-bounds shift amounts in
 AVX2 vector shift nodes #83840

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 14 ++++++++++++
 llvm/test/CodeGen/X86/combine-sra.ll    | 30 +++++++++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6eaaec407dbb08..ea15702c1da439 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -28927,6 +28927,9 @@ SDValue X86TargetLowering::LowerWin64_INT128_TO_FP(SDValue Op,
 // supported by the Subtarget
 static bool supportedVectorShiftWithImm(EVT VT, const X86Subtarget &Subtarget,
                                         unsigned Opcode) {
+  assert((Opcode == ISD::SHL || Opcode == ISD::SRA || Opcode == ISD::SRL) &&
+         "Unexpected Opcode!");
+
   if (!VT.isSimple())
     return false;
 
@@ -47291,6 +47294,17 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
     return V;
 
+  // fold (sra X, (umin Y, BitWidth-1)) -> (X86ISD::VSRAV X, Y)
+  APInt ShiftAmt;
+  if (supportedVectorVarShift(VT, Subtarget, ISD::SRA) &&
+      N1.getOpcode() == ISD::UMIN &&
+      ISD::isConstantSplatVector(N1.getOperand(1).getNode(), ShiftAmt) &&
+      ShiftAmt == VT.getScalarSizeInBits() - 1) {
+    SDValue ShrAmtVal = N1.getOperand(0);
+    SDLoc DL(N);
+    return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
+  }
+
   // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
   // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
   // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
diff --git a/llvm/test/CodeGen/X86/combine-sra.ll b/llvm/test/CodeGen/X86/combine-sra.ll
index cc0ed2b8268c66..462e4de211260c 100644
--- a/llvm/test/CodeGen/X86/combine-sra.ll
+++ b/llvm/test/CodeGen/X86/combine-sra.ll
@@ -382,3 +382,33 @@ define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>
   ret <4 x i32> %2
 }
+
+define <4 x i32> @combine_vec_ashr_out_of_bound(<4 x i32> %x, <4 x i32> %y) {
+; SSE-LABEL: combine_vec_ashr_out_of_bound:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pminud {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; SSE-NEXT:    pshuflw {{.*#+}} xmm2 = xmm1[2,3,3,3,4,5,6,7]
+; SSE-NEXT:    movdqa %xmm0, %xmm3
+; SSE-NEXT:    psrad %xmm2, %xmm3
+; SSE-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[2,3,2,3]
+; SSE-NEXT:    pshuflw {{.*#+}} xmm4 = xmm2[2,3,3,3,4,5,6,7]
+; SSE-NEXT:    movdqa %xmm0, %xmm5
+; SSE-NEXT:    psrad %xmm4, %xmm5
+; SSE-NEXT:    pblendw {{.*#+}} xmm5 = xmm3[0,1,2,3],xmm5[4,5,6,7]
+; SSE-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,1,4,5,6,7]
+; SSE-NEXT:    movdqa %xmm0, %xmm3
+; SSE-NEXT:    psrad %xmm1, %xmm3
+; SSE-NEXT:    pshuflw {{.*#+}} xmm1 = xmm2[0,1,1,1,4,5,6,7]
+; SSE-NEXT:    psrad %xmm1, %xmm0
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm3[0,1,2,3],xmm0[4,5,6,7]
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm5[2,3],xmm0[4,5],xmm5[6,7]
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: combine_vec_ashr_out_of_bound:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsravd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %y, <4 x i32> <i32 31, i32 31, i32 31, i32 31>)
+  %2 = ashr <4 x i32> %x, %1
+  ret <4 x i32> %2
+}
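
The combine only fires when the clamp constant equals the element bit width
minus one (31 for v4i32): vpsravd already treats shift amounts of 32 or more
as a shift by 31, filling each element with its sign bit, so the umin is
redundant in that case. For contrast, here is a minimal sketch (not part of
the patch; the function name is hypothetical) of a case the splat-constant
check deliberately rejects, since clamping to 30 is observable whenever an
element of %y is 31 or larger:

declare <4 x i32> @llvm.umin.v4i32(<4 x i32>, <4 x i32>)

define <4 x i32> @ashr_clamp_to_30(<4 x i32> %x, <4 x i32> %y) {
  ; Clamp to 30 rather than 31: for %y elements >= 31, ashr-by-30 and
  ; ashr-by-31 can differ, so this must not become a plain variable ashr of %y.
  %amt = tail call <4 x i32> @llvm.umin.v4i32(<4 x i32> %y, <4 x i32> <i32 30, i32 30, i32 30, i32 30>)
  %r = ashr <4 x i32> %x, %amt
  ret <4 x i32> %r
}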


