[llvm] bfe2f5d - [InstCombine] Fix buggy `(mul X, Y)` -> `(shl X, Log2(Y))` transform PR62175

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 18 15:18:12 PDT 2023


Author: Noah Goldstein
Date: 2023-04-18T17:17:48-05:00
New Revision: bfe2f5d38bb14bf7ce4f44d3de558fbc076bdc1a

URL: https://github.com/llvm/llvm-project/commit/bfe2f5d38bb14bf7ce4f44d3de558fbc076bdc1a
DIFF: https://github.com/llvm/llvm-project/commit/bfe2f5d38bb14bf7ce4f44d3de558fbc076bdc1a.diff

LOG: [InstCombine] Fix buggy `(mul X, Y)` -> `(shl X, Log2(Y))` transform PR62175

The bug was that we recognized patterns like `(shl 4, Z)` as a power of
2 whose Log2 we could take (`2 + Z`), but doing `(shl X, (2 + Z))` can
cause a poison shift.
    https://alive2.llvm.org/ce/z/yuJm_k

The fix is to verify that `Log2(Y)` will be a non-poisonous shift
amount. We can do this with:
    `nsw` flag:
        - https://alive2.llvm.org/ce/z/yyyJBr
        - https://alive2.llvm.org/ce/z/YgubD_
    `nuw` flag:
        - https://alive2.llvm.org/ce/z/-4mpyV
        - https://alive2.llvm.org/ce/z/a6ik6r
    Prove `Y != 0`:
        - https://alive2.llvm.org/ce/z/ced4su
        - https://alive2.llvm.org/ce/z/X-JJHb

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D148609

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/div-shift.ll
    llvm/test/Transforms/InstCombine/mul-pow2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 9d0e171f98d43..19f2d2fde6fea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -186,7 +186,7 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
 }
 
 static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
-                       bool DoFold);
+                       bool AssumeNonZero, bool DoFold);
 
 Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -486,15 +486,19 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   //        (shl Op1, Log2(Op0))
   //    if Log2(Op1) folds away ->
   //        (shl Op0, Log2(Op1))
-  if (takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ false)) {
-    Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ true);
+  if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+               /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+                          /*DoFold*/ true);
     BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
     // We can only propegate nuw flag.
     Shl->setHasNoUnsignedWrap(HasNUW);
     return Shl;
   }
-  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ false)) {
-    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ true);
+  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+               /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+                          /*DoFold*/ true);
     BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
     // We can only propegate nuw flag.
     Shl->setHasNoUnsignedWrap(HasNUW);
@@ -1181,7 +1185,7 @@ static const unsigned MaxDepth = 6;
 // actual instructions, otherwise return a non-null dummy value. Return nullptr
 // on failure.
 static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
-                       bool DoFold) {
+                       bool AssumeNonZero, bool DoFold) {
   auto IfFold = [DoFold](function_ref<Value *()> Fn) {
     if (!DoFold)
       return reinterpret_cast<Value *>(-1);
@@ -1207,14 +1211,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // FIXME: Require one use?
   Value *X, *Y;
   if (match(Op, m_ZExt(m_Value(X))))
-    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+    if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
       return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
 
   // log2(X << Y) -> log2(X) + Y
   // FIXME: Require one use unless X is 1?
-  if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
-    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
-      return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+  if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+    auto *BO = cast<OverflowingBinaryOperator>(Op);
+    // nuw will be set if the `shl` is trivially non-zero.
+    if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+      if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+        return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+  }
 
   // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
   // FIXME: missed optimization: if one of the hands of select is/contains
@@ -1222,8 +1230,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // FIXME: can both hands contain undef?
   // FIXME: Require one use?
   if (SelectInst *SI = dyn_cast<SelectInst>(Op))
-    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
-      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
+                               AssumeNonZero, DoFold))
+      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
+                                 AssumeNonZero, DoFold))
         return IfFold([&]() {
           return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
         });
@@ -1231,13 +1241,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
   // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
   auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
-  if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
-    if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
-      if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
+  if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+    // Use AssumeNonZero as false here. Otherwise we can hit case where
+    // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
+    if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+                               /*AssumeNonZero*/ false, DoFold))
+      if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+                                 /*AssumeNonZero*/ false, DoFold))
         return IfFold([&]() {
-          return Builder.CreateBinaryIntrinsic(
-              MinMax->getIntrinsicID(), LogX, LogY);
+          return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+                                               LogY);
         });
+  }
 
   return nullptr;
 }
@@ -1357,8 +1372,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
   }
 
   // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
-  if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
-    Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+               /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+                          /*AssumeNonZero*/ true, /*DoFold*/ true);
     return replaceInstUsesWith(
         I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
   }

diff  --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index efdaa0ce9d059..76c5328dc8499 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -1000,3 +1000,28 @@ define i8 @udiv_shl_nuw_divisor(i8 %x, i8 %y, i8 %z) {
   %d = udiv i8 %x, %s
   ret i8 %d
 }
+
+define i8 @udiv_fail_shl_overflow(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_fail_shl_overflow(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i8 2, [[Y:%.*]]
+; CHECK-NEXT:    [[MIN:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 1)
+; CHECK-NEXT:    [[MUL:%.*]] = udiv i8 [[X:%.*]], [[MIN]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %shl = shl i8 2, %y
+  %min = call i8 @llvm.umax.i8(i8 %shl, i8 1)
+  %mul = udiv i8 %x, %min
+  ret i8 %mul
+}
+
+define i8 @udiv_shl_no_overflow(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_shl_no_overflow(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[Y:%.*]], 1
+; CHECK-NEXT:    [[MUL1:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i8 [[MUL1]]
+;
+  %shl = shl nuw i8 2, %y
+  %min = call i8 @llvm.umax.i8(i8 %shl, i8 1)
+  %mul = udiv i8 %x, %min
+  ret i8 %mul
+}

diff  --git a/llvm/test/Transforms/InstCombine/mul-pow2.ll b/llvm/test/Transforms/InstCombine/mul-pow2.ll
index 5617c74647d28..c16fd710f309b 100644
--- a/llvm/test/Transforms/InstCombine/mul-pow2.ll
+++ b/llvm/test/Transforms/InstCombine/mul-pow2.ll
@@ -102,3 +102,37 @@ define <2 x i8> @mul_x_selectp2_vec(<2 x i8> %xx, i1 %c) {
   %r = mul <2 x i8> %x, %s
   ret <2 x i8> %r
 }
+
+
+define i8 @shl_add_log_may_cause_poison_pr62175_fail(i8 %x, i8 %y) {
+; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_fail(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i8 4, [[X:%.*]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[SHL]], [[Y:%.*]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %shl = shl i8 4, %x
+  %mul = mul i8 %y, %shl
+  ret i8 %mul
+}
+
+define i8 @shl_add_log_may_cause_poison_pr62175_with_nuw(i8 %x, i8 %y) {
+; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nuw(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %shl = shl nuw i8 4, %x
+  %mul = mul i8 %y, %shl
+  ret i8 %mul
+}
+
+define i8 @shl_add_log_may_cause_poison_pr62175_with_nsw(i8 %x, i8 %y) {
+; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nsw(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[X:%.*]], 2
+; CHECK-NEXT:    [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i8 [[MUL]]
+;
+  %shl = shl nsw i8 4, %x
+  %mul = mul i8 %y, %shl
+  ret i8 %mul
+}


        


More information about the llvm-commits mailing list