[llvm] [VectorCombine] foldShuffleOfBinops - add support for length changing shuffles (PR #88899)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 09:02:28 PDT 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/88899

>From 2bda86ed766d023c1f921b4551c00497fbbf5e4a Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Tue, 16 Apr 2024 15:08:40 +0100
Subject: [PATCH] [VectorCombine] foldShuffleOfBinops - add support for length
 changing shuffles

Refactor to be closer to foldShuffleOfCastops
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 85 +++++++++++++------
 .../VectorCombine/X86/shuffle-of-binops.ll    | 61 +++++++------
 2 files changed, 94 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index da03a69708ddfc..da3c780550a083 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1395,60 +1395,91 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
-/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into
-/// "binop (shuffle), (shuffle)".
+/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
-  auto *VecTy = cast<FixedVectorType>(I.getType());
   BinaryOperator *B0, *B1;
-  ArrayRef<int> Mask;
+  ArrayRef<int> OldMask;
   if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)),
-                           m_Mask(Mask))) ||
-      B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy)
+                           m_Mask(OldMask))))
     return false;
 
   // Don't introduce poison into div/rem.
-  if (any_of(Mask, [](int M) { return M == PoisonMaskElem; }) &&
+  if (any_of(OldMask, [](int M) { return M == PoisonMaskElem; }) &&
       B0->isIntDivRem())
     return false;
 
-  // Try to replace a binop with a shuffle if the shuffle is not costly.
-  // The new shuffle will choose from a single, common operand, so it may be
-  // cheaper than the existing two-operand shuffle.
-  SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size());
+  // TODO: Add support for addlike etc.
   Instruction::BinaryOps Opcode = B0->getOpcode();
-  InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
-  InstructionCost ShufCost = TTI.getShuffleCost(
-      TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask);
-  if (ShufCost > BinopCost)
+  if (Opcode != B1->getOpcode())
+    return false;
+
+  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *BinOpTy = dyn_cast<FixedVectorType>(B0->getType());
+  if (!ShuffleDstTy || !BinOpTy)
     return false;
 
+  unsigned NumSrcElts = BinOpTy->getNumElements();
+
   // If we have something like "add X, Y" and "add Z, X", swap ops to match.
   Value *X = B0->getOperand(0), *Y = B0->getOperand(1);
   Value *Z = B1->getOperand(0), *W = B1->getOperand(1);
-  if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W)
+  if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W &&
+      (X == W || Y == Z))
     std::swap(X, Y);
 
-  Value *Shuf0, *Shuf1;
+  auto ConvertToUnary = [NumSrcElts](int &M) {
+    if (M >= (int)NumSrcElts)
+      M -= NumSrcElts;
+  };
+
+  SmallVector<int> NewMask0(OldMask.begin(), OldMask.end());
+  TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
   if (X == Z) {
-    // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W)
-    Shuf0 = Builder.CreateShuffleVector(X, UnaryMask);
-    Shuf1 = Builder.CreateShuffleVector(Y, W, Mask);
-  } else if (Y == W) {
-    // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y)
-    Shuf0 = Builder.CreateShuffleVector(X, Z, Mask);
-    Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask);
-  } else {
-    return false;
+    llvm::for_each(NewMask0, ConvertToUnary);
+    SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
+    Z = PoisonValue::get(BinOpTy);
   }
 
+  SmallVector<int> NewMask1(OldMask.begin(), OldMask.end());
+  TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
+  if (Y == W) {
+    llvm::for_each(NewMask1, ConvertToUnary);
+    SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
+    W = PoisonValue::get(BinOpTy);
+  }
+
+  // Try to replace a binop with a shuffle if the shuffle is not costly.
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(B0->getOpcode(), BinOpTy, CostKind) +
+      TTI.getArithmeticInstrCost(B1->getOpcode(), BinOpTy, CostKind) +
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy,
+                         OldMask, CostKind, 0, nullptr, {B0, B1}, &I);
+
+  InstructionCost NewCost =
+      TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
+      TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}) +
+      TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
+                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost >= OldCost)
+    return false;
+
+  Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
+  Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
   Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1);
+
   // Intersect flags from the old binops.
   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
     NewInst->copyIRFlags(B0);
     NewInst->andIRFlags(B1);
   }
 
-  // TODO: Add Shuf0/Shuf1 to WorkList?
+  Worklist.pushValue(Shuf0);
+  Worklist.pushValue(Shuf1);
   replaceValue(I, *NewBO);
   return true;
 }
diff --git a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll
index 5c4ad4f1fcc4e5..94a31aa45f3688 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-binops.ll
@@ -25,9 +25,9 @@ define <4 x float> @shuf_fdiv_v4f32_yy(<4 x float> %x, <4 x float> %y, <4 x floa
 define <4 x i32> @shuf_add_v4i32_xx(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
 ; CHECK-LABEL: define <4 x i32> @shuf_add_v4i32_xx(
 ; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> [[Z:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[X]], <4 x i32> poison, <4 x i32> <i32 poison, i32 poison, i32 2, i32 0>
-; CHECK-NEXT:    [[R1:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> [[Z]], <4 x i32> <i32 poison, i32 poison, i32 6, i32 0>
-; CHECK-NEXT:    [[R2:%.*]] = add <4 x i32> [[TMP1]], [[R1]]
+; CHECK-NEXT:    [[B0:%.*]] = add <4 x i32> [[X]], [[Y]]
+; CHECK-NEXT:    [[B1:%.*]] = add <4 x i32> [[X]], [[Z]]
+; CHECK-NEXT:    [[R2:%.*]] = shufflevector <4 x i32> [[B0]], <4 x i32> [[B1]], <4 x i32> <i32 poison, i32 poison, i32 6, i32 0>
 ; CHECK-NEXT:    ret <4 x i32> [[R2]]
 ;
   %b0 = add <4 x i32> %x, %y
@@ -36,15 +36,22 @@ define <4 x i32> @shuf_add_v4i32_xx(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
   ret <4 x i32> %r
 }
 
-; For commutative instructions, common operand may be swapped.
+; For commutative instructions, common operand may be swapped (SSE - expensive fmul vs AVX - cheap fmul)
 
 define <4 x float> @shuf_fmul_v4f32_xx_swap(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
-; CHECK-LABEL: define <4 x float> @shuf_fmul_v4f32_xx_swap(
-; CHECK-SAME: <4 x float> [[X:%.*]], <4 x float> [[Y:%.*]], <4 x float> [[Z:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[Y]], <4 x float> [[Z]], <4 x i32> <i32 0, i32 3, i32 4, i32 7>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[X]], <4 x float> poison, <4 x i32> <i32 0, i32 3, i32 0, i32 3>
-; CHECK-NEXT:    [[R:%.*]] = fmul <4 x float> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    ret <4 x float> [[R]]
+; SSE-LABEL: define <4 x float> @shuf_fmul_v4f32_xx_swap(
+; SSE-SAME: <4 x float> [[X:%.*]], <4 x float> [[Y:%.*]], <4 x float> [[Z:%.*]]) #[[ATTR0]] {
+; SSE-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[Y]], <4 x float> [[Z]], <4 x i32> <i32 0, i32 3, i32 4, i32 7>
+; SSE-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[X]], <4 x float> poison, <4 x i32> <i32 0, i32 3, i32 0, i32 3>
+; SSE-NEXT:    [[R:%.*]] = fmul <4 x float> [[TMP1]], [[TMP2]]
+; SSE-NEXT:    ret <4 x float> [[R]]
+;
+; AVX-LABEL: define <4 x float> @shuf_fmul_v4f32_xx_swap(
+; AVX-SAME: <4 x float> [[X:%.*]], <4 x float> [[Y:%.*]], <4 x float> [[Z:%.*]]) #[[ATTR0]] {
+; AVX-NEXT:    [[B0:%.*]] = fmul <4 x float> [[X]], [[Y]]
+; AVX-NEXT:    [[B1:%.*]] = fmul <4 x float> [[Z]], [[X]]
+; AVX-NEXT:    [[R:%.*]] = shufflevector <4 x float> [[B0]], <4 x float> [[B1]], <4 x i32> <i32 0, i32 3, i32 4, i32 7>
+; AVX-NEXT:    ret <4 x float> [[R]]
 ;
   %b0 = fmul <4 x float> %x, %y
   %b1 = fmul <4 x float> %z, %x
@@ -57,9 +64,9 @@ define <4 x float> @shuf_fmul_v4f32_xx_swap(<4 x float> %x, <4 x float> %y, <4 x
 define <2 x i64> @shuf_and_v2i64_yy_swap(<2 x i64> %x, <2 x i64> %y, <2 x i64> %z) {
 ; CHECK-LABEL: define <2 x i64> @shuf_and_v2i64_yy_swap(
 ; CHECK-SAME: <2 x i64> [[X:%.*]], <2 x i64> [[Y:%.*]], <2 x i64> [[Z:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x i64> [[Y]], <2 x i64> poison, <2 x i32> <i32 1, i32 0>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i64> [[X]], <2 x i64> [[Z]], <2 x i32> <i32 3, i32 0>
-; CHECK-NEXT:    [[R:%.*]] = and <2 x i64> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[B0:%.*]] = and <2 x i64> [[X]], [[Y]]
+; CHECK-NEXT:    [[B1:%.*]] = and <2 x i64> [[Y]], [[Z]]
+; CHECK-NEXT:    [[R:%.*]] = shufflevector <2 x i64> [[B0]], <2 x i64> [[B1]], <2 x i32> <i32 3, i32 0>
 ; CHECK-NEXT:    ret <2 x i64> [[R]]
 ;
   %b0 = and <2 x i64> %x, %y
@@ -84,15 +91,22 @@ define <4 x i32> @shuf_shl_v4i32_xx(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
   ret <4 x i32> %r
 }
 
-; negative test - common operand, but not commutable
+; common operand, but not commutable (SSE - expensive vector shift vs AVX2 - cheap vector shift)
 
 define <4 x i32> @shuf_shl_v4i32_xx_swap(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
-; CHECK-LABEL: define <4 x i32> @shuf_shl_v4i32_xx_swap(
-; CHECK-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> [[Z:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[B0:%.*]] = shl <4 x i32> [[X]], [[Y]]
-; CHECK-NEXT:    [[B1:%.*]] = shl <4 x i32> [[Z]], [[X]]
-; CHECK-NEXT:    [[R1:%.*]] = shufflevector <4 x i32> [[B0]], <4 x i32> [[B1]], <4 x i32> <i32 3, i32 2, i32 2, i32 5>
-; CHECK-NEXT:    ret <4 x i32> [[R1]]
+; SSE-LABEL: define <4 x i32> @shuf_shl_v4i32_xx_swap(
+; SSE-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> [[Z:%.*]]) #[[ATTR0]] {
+; SSE-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[X]], <4 x i32> [[Z]], <4 x i32> <i32 3, i32 2, i32 2, i32 5>
+; SSE-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> [[X]], <4 x i32> <i32 3, i32 2, i32 2, i32 5>
+; SSE-NEXT:    [[R:%.*]] = shl <4 x i32> [[TMP1]], [[TMP2]]
+; SSE-NEXT:    ret <4 x i32> [[R]]
+;
+; AVX-LABEL: define <4 x i32> @shuf_shl_v4i32_xx_swap(
+; AVX-SAME: <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> [[Z:%.*]]) #[[ATTR0]] {
+; AVX-NEXT:    [[B0:%.*]] = shl <4 x i32> [[X]], [[Y]]
+; AVX-NEXT:    [[B1:%.*]] = shl <4 x i32> [[Z]], [[X]]
+; AVX-NEXT:    [[R:%.*]] = shufflevector <4 x i32> [[B0]], <4 x i32> [[B1]], <4 x i32> <i32 3, i32 2, i32 2, i32 5>
+; AVX-NEXT:    ret <4 x i32> [[R]]
 ;
   %b0 = shl <4 x i32> %x, %y
   %b1 = shl <4 x i32> %z, %x
@@ -116,7 +130,7 @@ define <2 x i64> @shuf_sub_add_v2i64_yy(<2 x i64> %x, <2 x i64> %y, <2 x i64> %z
   ret <2 x i64> %r
 }
 
-; negative test - type change via shuffle
+; type change via shuffle
 
 define <8 x float> @shuf_fmul_v4f32_xx_type(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
 ; CHECK-LABEL: define <8 x float> @shuf_fmul_v4f32_xx_type(
@@ -168,7 +182,7 @@ define <4 x i32> @shuf_mul_v4i32_yy_use2(<4 x i32> %x, <4 x i32> %y, <4 x i32> %
   ret <4 x i32> %r
 }
 
-; negative test - must have matching operand
+; must have matching operand
 
 define <4 x float> @shuf_fadd_v4f32_no_common_op(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x float> %w) {
 ; CHECK-LABEL: define <4 x float> @shuf_fadd_v4f32_no_common_op(
@@ -216,6 +230,3 @@ define <4 x i32> @shuf_srem_v4i32_poison(<4 x i32> %a0, <4 x i32> %a1) {
   ret <4 x i32> %r
 }
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; AVX: {{.*}}
-; SSE: {{.*}}



More information about the llvm-commits mailing list