[llvm] [VectorCombine] Add Cmp and Select for shuffleToIdentity (PR #92794)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 04:28:38 PDT 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/92794

>From 008ad1c8dc994a6e5af90029c2e2c2a0a0510c66 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Thu, 23 May 2024 18:45:42 +0100
Subject: [PATCH 1/2] [VectorCombine] Add Cmp and Select for shuffleToIdentity

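Add support for compare and select instructions to foldShuffleToIdentity.
Other than some additional checks needed for compare predicates (which must
match across all lanes) and selects with scalar condition operands (which are
rejected), these are relatively simple additions to what already exists.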
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 19 ++++++++++--
 .../AArch64/shuffletoidentity.ll              | 30 +++----------------
 2 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 03a69d661acde..c15464b992764 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1742,6 +1742,10 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
   if (auto *BI = dyn_cast<BinaryOperator>(I))
     return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
                                Ops[1]);
+  if (auto CI = dyn_cast<CmpInst>(I))
+    return Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
+  if (auto SI = dyn_cast<SelectInst>(I))
+    return Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
   if (II)
     return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
   assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
@@ -1821,6 +1825,12 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
             return false;
           if (V->getValueID() != FrontV->getValueID())
             return false;
+          if (auto *CI = dyn_cast<CmpInst>(V))
+            if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
+              return false;
+          if (auto *SI = dyn_cast<SelectInst>(V))
+            if (!isa<VectorType>(SI->getOperand(0)->getType()))
+              return false;
           if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
             return false;
           auto *II = dyn_cast<IntrinsicInst>(V);
@@ -1832,12 +1842,17 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
 
     // Check the operator is one that we support. We exclude div/rem in case
     // they hit UB from poison lanes.
-    if (isa<BinaryOperator>(FrontV) &&
-        !cast<BinaryOperator>(FrontV)->isIntDivRem()) {
+    if ((isa<BinaryOperator>(FrontV) &&
+         !cast<BinaryOperator>(FrontV)->isIntDivRem()) ||
+        isa<CmpInst>(FrontV)) {
       Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
       Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
     } else if (isa<UnaryOperator>(FrontV)) {
       Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+    } else if (isa<SelectInst>(FrontV)) {
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
     } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
                II && isTriviallyVectorizable(II->getIntrinsicID())) {
       for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index df42777637ad8..5cbda8a1e112e 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -419,19 +419,8 @@ define <8 x i8> @extrause_shuffle(<8 x i8> %a, <8 x i8> %b) {
 
 define <8 x i8> @icmpsel(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-LABEL: @icmpsel(
-; CHECK-NEXT:    [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[BB:%.*]] = shufflevector <8 x i8> [[B:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[BT:%.*]] = shufflevector <8 x i8> [[B]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[ABT1:%.*]] = icmp slt <4 x i8> [[AT]], [[BT]]
-; CHECK-NEXT:    [[ABB1:%.*]] = icmp slt <4 x i8> [[AB]], [[BB]]
-; CHECK-NEXT:    [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
-; CHECK-NEXT:    [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <8 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = select <8 x i1> [[TMP1]], <8 x i8> [[C:%.*]], <8 x i8> [[D:%.*]]
 ; CHECK-NEXT:    ret <8 x i8> [[R]]
 ;
   %ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -485,19 +474,8 @@ define <8 x i8> @icmpsel_diffentcond(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x
 
 define <8 x i8> @fcmpsel(<8 x half> %a, <8 x half> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-LABEL: @fcmpsel(
-; CHECK-NEXT:    [[AB:%.*]] = shufflevector <8 x half> [[A:%.*]], <8 x half> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[AT:%.*]] = shufflevector <8 x half> [[A]], <8 x half> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[BB:%.*]] = shufflevector <8 x half> [[B:%.*]], <8 x half> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[BT:%.*]] = shufflevector <8 x half> [[B]], <8 x half> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[ABT1:%.*]] = fcmp olt <4 x half> [[AT]], [[BT]]
-; CHECK-NEXT:    [[ABB1:%.*]] = fcmp olt <4 x half> [[AB]], [[BB]]
-; CHECK-NEXT:    [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
-; CHECK-NEXT:    [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[TMP1:%.*]] = fcmp olt <8 x half> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = select <8 x i1> [[TMP1]], <8 x i8> [[C:%.*]], <8 x i8> [[D:%.*]]
 ; CHECK-NEXT:    ret <8 x i8> [[R]]
 ;
   %ab = shufflevector <8 x half> %a, <8 x half> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>

>From d294768f9a1f8d2e4be6225a62bf647719f0f5a0 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 28 May 2024 12:27:28 +0100
Subject: [PATCH 2/2] Rebase and update auto*

---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index c15464b992764..056f0d6b3ee6c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1742,9 +1742,9 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
   if (auto *BI = dyn_cast<BinaryOperator>(I))
     return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
                                Ops[1]);
-  if (auto CI = dyn_cast<CmpInst>(I))
+  if (auto *CI = dyn_cast<CmpInst>(I))
     return Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
-  if (auto SI = dyn_cast<SelectInst>(I))
+  if (auto *SI = dyn_cast<SelectInst>(I))
     return Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
   if (II)
     return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);



More information about the llvm-commits mailing list