[llvm] 0236c57 - [InstCombine] try to fold one-demanded-bit-of-multiply

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 4 08:42:38 PST 2022


Author: Sanjay Patel
Date: 2022-02-04T11:40:54-05:00
New Revision: 0236c571810dea1b72d99ee0af124a7a09976e00

URL: https://github.com/llvm/llvm-project/commit/0236c571810dea1b72d99ee0af124a7a09976e00
DIFF: https://github.com/llvm/llvm-project/commit/0236c571810dea1b72d99ee0af124a7a09976e00.diff

LOG: [InstCombine] try to fold one-demanded-bit-of-multiply

This is a generalization of the icmp fold proposed in D118061 (which can now be abandoned).
We're looking for a disguised form of "odd * odd must be odd".
Some Alive2 proofs to show correctness:
https://alive2.llvm.org/ce/z/60Y8hz
https://alive2.llvm.org/ce/z/HfAP6R

Differential Revision: https://reviews.llvm.org/D118539

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
    llvm/test/Transforms/InstCombine/icmp-mul-and.ll
    llvm/test/Transforms/InstCombine/mul-masked-bits.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 3f064cfda712..baa7a74f481d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -544,6 +544,23 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                         NSW, LHSKnown, RHSKnown);
     break;
   }
+  case Instruction::Mul: {
+    // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
+    // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
+    // odd (has LSB set), then the left-shifted low bit of X is the answer.
+    if (DemandedMask.isPowerOf2()) {
+      unsigned CTZ = DemandedMask.countTrailingZeros();
+      const APInt *C;
+      if (match(I->getOperand(1), m_APInt(C)) &&
+          C->countTrailingZeros() == CTZ) {
+        Constant *ShiftC = ConstantInt::get(I->getType(), CTZ);
+        Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
+        return InsertNewInstWith(Shl, *I);
+      }
+    }
+    computeKnownBits(I, Known, Depth, CxtI);
+    break;
+  }
   case Instruction::Shl: {
     const APInt *SA;
     if (match(I->getOperand(1), m_APInt(SA))) {

diff  --git a/llvm/test/Transforms/InstCombine/icmp-mul-and.ll b/llvm/test/Transforms/InstCombine/icmp-mul-and.ll
index 2d63bfac0ffc..bcce37de7687 100644
--- a/llvm/test/Transforms/InstCombine/icmp-mul-and.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-mul-and.ll
@@ -5,9 +5,8 @@ declare void @use(i8)
 
 define i1 @mul_mask_pow2_eq0(i8 %x) {
 ; CHECK-LABEL: @mul_mask_pow2_eq0(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 44
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MUL]], 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[AND]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %mul = mul i8 %x, 44
@@ -16,6 +15,9 @@ define i1 @mul_mask_pow2_eq0(i8 %x) {
   ret i1 %cmp
 }
 
+; TODO: Demanded bits does not convert the mul to a shift,
+; but the 'and' could be applied to 'x' directly.
+
 define i1 @mul_mask_pow2_ne0_use1(i8 %x) {
 ; CHECK-LABEL: @mul_mask_pow2_ne0_use1(
 ; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 40
@@ -31,10 +33,12 @@ define i1 @mul_mask_pow2_ne0_use1(i8 %x) {
   ret i1 %cmp
 }
 
+; negative test - extra use of 'and' would require more instructions
+
 define i1 @mul_mask_pow2_ne0_use2(i8 %x) {
 ; CHECK-LABEL: @mul_mask_pow2_ne0_use2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 40
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MUL]], 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[AND:%.*]] = and i8 [[TMP1]], 8
 ; CHECK-NEXT:    call void @use(i8 [[AND]])
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[AND]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
@@ -46,11 +50,12 @@ define i1 @mul_mask_pow2_ne0_use2(i8 %x) {
   ret i1 %cmp
 }
 
+; non-equality predicates are converted to equality
+
 define i1 @mul_mask_pow2_sgt0(i8 %x) {
 ; CHECK-LABEL: @mul_mask_pow2_sgt0(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 44
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MUL]], 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[AND]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %mul = mul i8 %x, 44
@@ -59,11 +64,12 @@ define i1 @mul_mask_pow2_sgt0(i8 %x) {
   ret i1 %cmp
 }
 
+; unnecessary mask bits are removed
+
 define i1 @mul_mask_fakepow2_ne0(i8 %x) {
 ; CHECK-LABEL: @mul_mask_fakepow2_ne0(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 44
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MUL]], 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[AND]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %mul = mul i8 %x, 44
@@ -72,11 +78,12 @@ define i1 @mul_mask_fakepow2_ne0(i8 %x) {
   ret i1 %cmp
 }
 
+; non-zero cmp constant is converted
+
 define i1 @mul_mask_pow2_eq4(i8 %x) {
 ; CHECK-LABEL: @mul_mask_pow2_eq4(
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 44
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[MUL]], 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[AND]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[TMP1]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %mul = mul i8 %x, 44
@@ -85,6 +92,8 @@ define i1 @mul_mask_pow2_eq4(i8 %x) {
   ret i1 %cmp
 }
 
+; negative test - must be pow2 mask constant
+
 define i1 @mul_mask_notpow2_ne(i8 %x) {
 ; CHECK-LABEL: @mul_mask_notpow2_ne(
 ; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X:%.*]], 60
@@ -101,9 +110,8 @@ define i1 @mul_mask_notpow2_ne(i8 %x) {
 define i1 @pr40493(i32 %area) {
 ; CHECK-LABEL: @pr40493(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[AREA:%.*]], 12
-; CHECK-NEXT:    [[REM:%.*]] = and i32 [[MUL]], 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[REM]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = and i32 [[AREA:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP0]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -146,8 +154,8 @@ entry:
 define i32 @pr40493_neg3(i32 %area) {
 ; CHECK-LABEL: @pr40493_neg3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[AREA:%.*]], 12
-; CHECK-NEXT:    [[REM:%.*]] = and i32 [[MUL]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[AREA:%.*]], 2
+; CHECK-NEXT:    [[REM:%.*]] = and i32 [[TMP0]], 4
 ; CHECK-NEXT:    ret i32 [[REM]]
 ;
 entry:
@@ -159,9 +167,8 @@ entry:
 define <4 x i1> @pr40493_vec1(<4 x i32> %area) {
 ; CHECK-LABEL: @pr40493_vec1(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL:%.*]] = mul <4 x i32> [[AREA:%.*]], <i32 12, i32 12, i32 12, i32 12>
-; CHECK-NEXT:    [[REM:%.*]] = and <4 x i32> [[MUL]], <i32 4, i32 4, i32 4, i32 4>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[REM]], zeroinitializer
+; CHECK-NEXT:    [[TMP0:%.*]] = and <4 x i32> [[AREA:%.*]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[TMP0]], zeroinitializer
 ; CHECK-NEXT:    ret <4 x i1> [[CMP]]
 ;
 entry:

diff  --git a/llvm/test/Transforms/InstCombine/mul-masked-bits.ll b/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
index 6bd4eb8463d9..d872c7c9c868 100644
--- a/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
+++ b/llvm/test/Transforms/InstCombine/mul-masked-bits.ll
@@ -61,8 +61,8 @@ define <4 x i1> @PR48683_vec_undef(<4 x i32> %x) {
 
 define i8 @one_demanded_bit(i8 %x) {
 ; CHECK-LABEL: @one_demanded_bit(
-; CHECK-NEXT:    [[M:%.*]] = mul i8 [[X:%.*]], -64
-; CHECK-NEXT:    [[R:%.*]] = or i8 [[M]], -65
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X:%.*]], 6
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[TMP1]], -65
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %m = mul i8 %x, 192  ; 0b1100_0000
@@ -72,8 +72,8 @@ define i8 @one_demanded_bit(i8 %x) {
 
 define <2 x i8> @one_demanded_bit_splat(<2 x i8> %x) {
 ; CHECK-LABEL: @one_demanded_bit_splat(
-; CHECK-NEXT:    [[M:%.*]] = mul <2 x i8> [[X:%.*]], <i8 -96, i8 -96>
-; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[M]], <i8 32, i8 32>
+; CHECK-NEXT:    [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], <i8 5, i8 5>
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[TMP1]], <i8 32, i8 32>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %m = mul <2 x i8> %x, <i8 160, i8 160> ; 0b1010_0000
@@ -83,8 +83,7 @@ define <2 x i8> @one_demanded_bit_splat(<2 x i8> %x) {
 
 define i67 @one_demanded_low_bit(i67 %x) {
 ; CHECK-LABEL: @one_demanded_low_bit(
-; CHECK-NEXT:    [[M:%.*]] = mul i67 [[X:%.*]], -63
-; CHECK-NEXT:    [[R:%.*]] = and i67 [[M]], 1
+; CHECK-NEXT:    [[R:%.*]] = and i67 [[X:%.*]], 1
 ; CHECK-NEXT:    ret i67 [[R]]
 ;
   %m = mul i67 %x, -63 ; any odd number will do


        


More information about the llvm-commits mailing list