[llvm] 513251b - [InstCombine] Improve transforms for `(mul X, Y)` -> `(shl X, log2(Y))`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 7 12:58:35 PDT 2023


Author: Noah Goldstein
Date: 2023-04-07T14:58:20-05:00
New Revision: 513251b76582cf0d3ba81dfe36a05d452001b3b3

URL: https://github.com/llvm/llvm-project/commit/513251b76582cf0d3ba81dfe36a05d452001b3b3
DIFF: https://github.com/llvm/llvm-project/commit/513251b76582cf0d3ba81dfe36a05d452001b3b3.diff

LOG: [InstCombine] Improve transforms for `(mul X, Y)` -> `(shl X, log2(Y))`

Using the more robust log2 search (`takeLog2`) lets us fold more cases
(this is the same logic that already exists for idiv/irem).

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D146347

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/mul-pow2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 5d4dd2fdac66d..9d0e171f98d43 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
   return nullptr;
 }
 
+static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
+                       bool DoFold);
+
 Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   if (Value *V =
@@ -478,6 +481,26 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
                                     m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
     return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
 
+  // (mul Op0 Op1):
+  //    if Log2(Op0) folds away ->
+  //        (shl Op1, Log2(Op0))
+  //    if Log2(Op1) folds away ->
+  //        (shl Op0, Log2(Op1))
+  if (takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ true);
+    BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
+    // We can only propegate nuw flag.
+    Shl->setHasNoUnsignedWrap(HasNUW);
+    return Shl;
+  }
+  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ true);
+    BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
+    // We can only propegate nuw flag.
+    Shl->setHasNoUnsignedWrap(HasNUW);
+    return Shl;
+  }
+
   bool Changed = false;
   if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
     Changed = true;

diff  --git a/llvm/test/Transforms/InstCombine/mul-pow2.ll b/llvm/test/Transforms/InstCombine/mul-pow2.ll
index 58dd6a23c076e..5617c74647d28 100644
--- a/llvm/test/Transforms/InstCombine/mul-pow2.ll
+++ b/llvm/test/Transforms/InstCombine/mul-pow2.ll
@@ -3,8 +3,8 @@
 declare void @use_i8(i8)
 define i8 @mul_selectp2_x(i8 %x, i1 %c) {
 ; CHECK-LABEL: @mul_selectp2_x(
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 2, i8 4
-; CHECK-NEXT:    [[R:%.*]] = mul i8 [[S]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], i8 1, i8 2
+; CHECK-NEXT:    [[R:%.*]] = shl i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %s = select i1 %c, i8 2, i8 4
@@ -15,8 +15,8 @@ define i8 @mul_selectp2_x(i8 %x, i1 %c) {
 
 define i8 @mul_selectp2_x_propegate_nuw(i8 %x, i1 %c) {
 ; CHECK-LABEL: @mul_selectp2_x_propegate_nuw(
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 2, i8 4
-; CHECK-NEXT:    [[R:%.*]] = mul nuw nsw i8 [[S]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], i8 1, i8 2
+; CHECK-NEXT:    [[R:%.*]] = shl nuw i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %s = select i1 %c, i8 2, i8 4
@@ -28,7 +28,8 @@ define i8 @mul_selectp2_x_propegate_nuw(i8 %x, i1 %c) {
 define i8 @mul_selectp2_x_multiuse_fixme(i8 %x, i1 %c) {
 ; CHECK-LABEL: @mul_selectp2_x_multiuse_fixme(
 ; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 2, i8 4
-; CHECK-NEXT:    [[R:%.*]] = mul i8 [[S]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C]], i8 1, i8 2
+; CHECK-NEXT:    [[R:%.*]] = shl i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    call void @use_i8(i8 [[S]])
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
@@ -40,9 +41,8 @@ define i8 @mul_selectp2_x_multiuse_fixme(i8 %x, i1 %c) {
 
 define i8 @mul_selectp2_x_non_const(i8 %x, i1 %c, i8 %yy) {
 ; CHECK-LABEL: @mul_selectp2_x_non_const(
-; CHECK-NEXT:    [[Y:%.*]] = shl nuw i8 1, [[YY:%.*]]
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 2, i8 [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = mul i8 [[S]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], i8 1, i8 [[YY:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = shl i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %y = shl i8 1, %yy
@@ -54,8 +54,8 @@ define i8 @mul_selectp2_x_non_const(i8 %x, i1 %c, i8 %yy) {
 define i8 @mul_selectp2_x_non_const_multiuse(i8 %x, i1 %c, i8 %yy) {
 ; CHECK-LABEL: @mul_selectp2_x_non_const_multiuse(
 ; CHECK-NEXT:    [[Y:%.*]] = shl nuw i8 1, [[YY:%.*]]
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 2, i8 [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = mul i8 [[S]], [[X:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], i8 1, i8 [[YY]]
+; CHECK-NEXT:    [[R:%.*]] = shl i8 [[X:%.*]], [[TMP1]]
 ; CHECK-NEXT:    call void @use_i8(i8 [[Y]])
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
@@ -69,8 +69,8 @@ define i8 @mul_selectp2_x_non_const_multiuse(i8 %x, i1 %c, i8 %yy) {
 define i8 @mul_x_selectp2(i8 %xx, i1 %c) {
 ; CHECK-LABEL: @mul_x_selectp2(
 ; CHECK-NEXT:    [[X:%.*]] = mul i8 [[XX:%.*]], [[XX]]
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i8 8, i8 1
-; CHECK-NEXT:    [[R:%.*]] = mul i8 [[X]], [[S]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], i8 3, i8 0
+; CHECK-NEXT:    [[R:%.*]] = shl i8 [[X]], [[TMP1]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x = mul i8 %xx, %xx
@@ -93,8 +93,8 @@ define i8 @mul_select_nonp2_x_fail(i8 %x, i1 %c) {
 define <2 x i8> @mul_x_selectp2_vec(<2 x i8> %xx, i1 %c) {
 ; CHECK-LABEL: @mul_x_selectp2_vec(
 ; CHECK-NEXT:    [[X:%.*]] = mul <2 x i8> [[XX:%.*]], [[XX]]
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], <2 x i8> <i8 8, i8 16>, <2 x i8> <i8 4, i8 1>
-; CHECK-NEXT:    [[R:%.*]] = mul <2 x i8> [[X]], [[S]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C:%.*]], <2 x i8> <i8 3, i8 4>, <2 x i8> <i8 2, i8 0>
+; CHECK-NEXT:    [[R:%.*]] = shl <2 x i8> [[X]], [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x = mul <2 x i8> %xx, %xx


        


More information about the llvm-commits mailing list