[llvm] c36b7e2 - [InstCombine] enhance vector bitwise select matching

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 9 05:55:10 PST 2021


Author: Sanjay Patel
Date: 2021-11-09T08:54:59-05:00
New Revision: c36b7e21bd8f04a44d6935c3469b1bcbbafeeb2d

URL: https://github.com/llvm/llvm-project/commit/c36b7e21bd8f04a44d6935c3469b1bcbbafeeb2d
DIFF: https://github.com/llvm/llvm-project/commit/c36b7e21bd8f04a44d6935c3469b1bcbbafeeb2d.diff

LOG: [InstCombine] enhance vector bitwise select matching

(Cond & C) | (~bitcast(Cond) & D) --> bitcast (select Cond, (bc C), (bc D))

This is part of fixing:
https://llvm.org/PR34047

That report shows a case where a bitcast sits between the select condition
candidate and its 'not' value because of the current cast canonicalization rules.

There is a bitcast type restriction that the existing matching might violate,
but I still need to investigate whether that is possible. Alive2 shows that we
can only do this transform safely when the bitcast is from narrow to wide vector
elements (otherwise, poison could leak into elements that were poison-free in the
original code):
https://alive2.llvm.org/ce/z/Hf66qh

Differential Revision: https://reviews.llvm.org/D113035

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/logical-select.ll
    llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 7e40e358b6a22..ebd2e3ebc60d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2298,22 +2298,30 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
   if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy())
     return nullptr;
 
-  // We need 0 or all-1's bitmasks.
-  if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits())
-    return nullptr;
-
-  // If B is the 'not' value of A, we have our answer.
+  // If A is the 'not' operand of B and has enough signbits, we have our answer.
   if (match(B, m_Not(m_Specific(A)))) {
     // If these are scalars or vectors of i1, A can be used directly.
     if (Ty->isIntOrIntVectorTy(1))
       return A;
-    return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty));
+
+    // If we look through a vector bitcast, the caller will bitcast the operands
+    // to match the condition's number of bits (N x i1).
+    // To make this poison-safe, disallow bitcast from wide element to narrow
+    // element. That could allow poison in lanes where it was not present in the
+    // original code.
+    A = peekThroughBitcast(A);
+    unsigned NumSignBits = ComputeNumSignBits(A);
+    if (NumSignBits == A->getType()->getScalarSizeInBits() &&
+        NumSignBits <= Ty->getScalarSizeInBits())
+      return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType()));
+    return nullptr;
   }
 
   // If both operands are constants, see if the constants are inverse bitmasks.
   Constant *AConst, *BConst;
   if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
-    if (AConst == ConstantExpr::getNot(BConst))
+    if (AConst == ConstantExpr::getNot(BConst) &&
+        ComputeNumSignBits(A) == Ty->getScalarSizeInBits())
       return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty));
 
   // Look for more complex patterns. The 'not' op may be hidden behind various
@@ -2357,10 +2365,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
   B = peekThroughBitcast(B, true);
   if (Value *Cond = getSelectCondition(A, B)) {
     // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
+    // If this is a vector, we may need to cast to match the condition's length.
     // The bitcasts will either all exist or all not exist. The builder will
     // not create unnecessary casts if the types already match.
-    Value *BitcastC = Builder.CreateBitCast(C, A->getType());
-    Value *BitcastD = Builder.CreateBitCast(D, A->getType());
+    Type *SelTy = A->getType();
+    if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) {
+      unsigned Elts = VecTy->getElementCount().getKnownMinValue();
+      Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts);
+      SelTy = VectorType::get(EltTy, VecTy->getElementCount());
+    }
+    Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+    Value *BitcastD = Builder.CreateBitCast(D, SelTy);
     Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
     return Builder.CreateBitCast(Select, OrigType);
   }

diff  --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll
index 610eb20eaf3e2..3e3cc11e6591b 100644
--- a/llvm/test/Transforms/InstCombine/logical-select.ll
+++ b/llvm/test/Transforms/InstCombine/logical-select.ll
@@ -682,15 +682,15 @@ define <4 x i32> @computesignbits_through_two_input_shuffle(<4 x i32> %x, <4 x i
   ret <4 x i32> %sel
 }
 
+; Bitcast of condition from narrow source element type can be converted to select.
+
 define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond(
-; CHECK-NEXT:    [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i64> [[T9]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i64> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP4]]
 ;
   %s = sext <16 x i1> %cond to <16 x i8>
   %t9 = bitcast <16 x i8> %s to <2 x i64>
@@ -701,6 +701,8 @@ define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d)
   ret <2 x i64> %r
 }
 
+; Negative test - bitcast of condition from wide source element type cannot be converted to select.
+
 define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute1(
 ; CHECK-NEXT:    [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]]
@@ -726,13 +728,11 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
 ; CHECK-LABEL: @bitcast_vec_cond_commute2(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
@@ -745,17 +745,18 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
   ret <2 x i16> %r
 }
 
+; Condition doesn't have to be a bool vec - just all signbits.
+
 define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute3(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], <i8 7, i8 7, i8 7, i8 7>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], <i8 -1, i8 -1, i8 -1, i8 -1>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization

diff  --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
index a658b19898896..964307b5b40b6 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
@@ -68,15 +68,9 @@ define internal <2 x i64> @_mm_set_epi32(i32 %__i3, i32 %__i2, i32 %__i1, i32 %_
 define <2 x i64> @abs_v4i32(<2 x i64> %x) {
 ; CHECK-LABEL: @abs_v4i32(
 ; CHECK-NEXT:    [[T1_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[SUB_I:%.*]] = sub <4 x i32> zeroinitializer, [[T1_I]]
-; CHECK-NEXT:    [[T1_I_LOBIT:%.*]] = ashr <4 x i32> [[T1_I]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[T1_I_LOBIT]] to <2 x i64>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = xor <2 x i64> [[TMP1]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I1:%.*]] = and <4 x i32> [[T1_I_LOBIT]], [[SUB_I]]
-; CHECK-NEXT:    [[AND_I_I:%.*]] = bitcast <4 x i32> [[AND_I_I1]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[T1_I]], i1 false)
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @_mm_set1_epi32(i32 -1)
   %call1 = call <2 x i64> @_mm_setzero_si128()
@@ -90,13 +84,9 @@ define <2 x i64> @max_v4i32(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-NEXT:    [[T0_I_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[T1_I_I:%.*]] = bitcast <2 x i64> [[Y:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I_I:%.*]] = icmp sgt <4 x i32> [[T0_I_I]], [[T1_I_I]]
-; CHECK-NEXT:    [[SEXT_I_I:%.*]] = sext <4 x i1> [[CMP_I_I]] to <4 x i32>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = bitcast <4 x i32> [[SEXT_I_I]] to <2 x i64>
-; CHECK-NEXT:    [[NEG_I_I:%.*]] = xor <2 x i64> [[T2_I_I]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I:%.*]] = and <2 x i64> [[NEG_I_I]], [[Y]]
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select <4 x i1> [[CMP_I_I]], <4 x i32> [[T0_I_I]], <4 x i32> [[T1_I_I]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @cmpgt_i32_sel_m128i(<2 x i64> %x, <2 x i64> %y, <2 x i64> %y, <2 x i64> %x)
   ret <2 x i64> %call


        


More information about the llvm-commits mailing list