[llvm] 98fde34 - [InstCombine] reduce redundant code for shl-binop folds

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 28 14:12:43 PDT 2021


Author: Sanjay Patel
Date: 2021-09-28T17:06:45-04:00
New Revision: 98fde3489a6d9ed3fa409623036e5ba5b99f1404

URL: https://github.com/llvm/llvm-project/commit/98fde3489a6d9ed3fa409623036e5ba5b99f1404
DIFF: https://github.com/llvm/llvm-project/commit/98fde3489a6d9ed3fa409623036e5ba5b99f1404.diff

LOG: [InstCombine] reduce redundant code for shl-binop folds

This is NFCI (no-functional-change-intended), but benign
diffs are possible with commutable ops, as seen in the
test diffs.

The transforms were repeated for the commutative opcodes,
but that should not be necessary if we canonicalize the
patterns that we're matching. If both operands of the
binop match, that should get folded eventually.
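
As a rough illustration of the single canonical pattern (the type,
values, and shift amount below are made up for illustration, not
taken from the test suite), a commuted input like:

  %s = lshr i8 %x, 3
  %b = add i8 %y, %s        ; shift-right on the RHS of the binop
  %r = shl i8 %b, 3

is first rewritten so the shift-right is on the LHS, i.e. treated as
((X >> C) + Y) << C, and then folded to roughly:

  %ys = shl i8 %y, 3        ; Y << C
  %b  = add i8 %x, %ys      ; X + (Y << C)
  %r  = and i8 %b, -8       ; ... & (~0 << C); -8 == 0xF8 for i8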

The transform that starts with a mask op seems to
over-constrain the use checks, so that could be a
potential enhancement.
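
For reference, a sketch of the mask-op form of the fold (again with
made-up values); in the current code, both the inner 'lshr' and the
'and' must be one-use for this to fire:

  %s = lshr i8 %x, 2
  %m = and i8 %s, 13        ; (X >> C) & CC
  %b = add i8 %m, %y
  %r = shl i8 %b, 2         ; (((X >> C) & CC) + Y) << C

folds to roughly:

  %xm = and i8 %x, 52       ; X & (CC << C); 13 << 2 == 52
  %ys = shl i8 %y, 2        ; Y << C
  %r  = add i8 %xm, %ys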

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/shl-bo.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index da650ec2f718..8d4113835dc7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -890,78 +890,61 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
       return new TruncInst(And, Ty);
     }
 
-    BinaryOperator *Op0BO;
-    if (match(Op0, m_OneUse(m_BinOp(Op0BO)))) {
-      // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
-      Value *V1;
-      const APInt *CC;
-      switch (Op0BO->getOpcode()) {
+    // If we have an opposite shift by the same amount, we may be able to
+    // reorder binops and shifts to eliminate math/logic.
+    auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
+      switch (BinOpcode) {
       default:
-        break;
+        return false;
       case Instruction::Add:
       case Instruction::And:
       case Instruction::Or:
-      case Instruction::Xor: {
-        // These operators commute.
-        // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
-        if (Op0BO->getOperand(1)->hasOneUse() &&
-            match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) {
-          Value *YS = // (Y << C)
-              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
-          // (X + (Y << C))
-          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
-                                         Op0BO->getOperand(1)->getName());
-          unsigned Op1Val = C->getLimitedValue(BitWidth);
-          APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
-          Constant *Mask = ConstantInt::get(Ty, Bits);
-          return BinaryOperator::CreateAnd(X, Mask);
-        }
-
-        // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
-        Value *Op0BOOp1 = Op0BO->getOperand(1);
-        if (Op0BOOp1->hasOneUse() &&
-            match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
-                                  m_APInt(CC)))) {
-          Value *YS = // (Y << C)
-              Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
-          // X & (CC << C)
-          Value *XM = Builder.CreateAnd(V1, ConstantInt::get(Ty, CC->shl(*C)),
-                                        V1->getName() + ".mask");
-          return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
-        }
-        LLVM_FALLTHROUGH;
+      case Instruction::Xor:
+      case Instruction::Sub:
+        // NOTE: Sub is not commutable and the transforms below may not be valid
+        //       when the shift-right is operand 1 (RHS) of the sub.
+        return true;
       }
-
-      case Instruction::Sub: {
-        // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
-        if (Op0BO->getOperand(0)->hasOneUse() &&
-            match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) {
-          Value *YS = // (Y << C)
-              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
-          // (X + (Y << C))
-          Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
-                                         Op0BO->getOperand(0)->getName());
-          unsigned Op1Val = C->getLimitedValue(BitWidth);
-          APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
-          Constant *Mask = ConstantInt::get(Ty, Bits);
-          return BinaryOperator::CreateAnd(X, Mask);
-        }
-
-        // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
-        if (Op0BO->getOperand(0)->hasOneUse() &&
-            match(Op0BO->getOperand(0),
-                  m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
-                        m_APInt(CC)))) {
-          Value *YS = // (Y << C)
-              Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
-          // X & (CC << C)
-          Value *XM = Builder.CreateAnd(V1, ConstantInt::get(Ty, CC->shl(*C)),
-                                        V1->getName() + ".mask");
-          return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
-        }
-
-        break;
+    };
+    BinaryOperator *Op0BO;
+    if (match(Op0, m_OneUse(m_BinOp(Op0BO))) &&
+        isSuitableBinOpcode(Op0BO->getOpcode())) {
+      // Commute so shift-right is on LHS of the binop.
+      // (Y bop (X >> C)) << C         ->  ((X >> C) bop Y) << C
+      // (Y bop ((X >> C) & CC)) << C  ->  (((X >> C) & CC) bop Y) << C
+      Value *Shr = Op0BO->getOperand(0);
+      Value *Y = Op0BO->getOperand(1);
+      Value *X;
+      const APInt *CC;
+      if (Op0BO->isCommutative() && Y->hasOneUse() &&
+          (match(Y, m_Shr(m_Value(), m_Specific(Op1))) ||
+           match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))),
+                          m_APInt(CC)))))
+        std::swap(Shr, Y);
+
+      // ((X >> C) bop Y) << C  ->  (X bop (Y << C)) & (~0 << C)
+      if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
+        // Y << C
+        Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
+        // (X bop (Y << C))
+        Value *B =
+            Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName());
+        unsigned Op1Val = C->getLimitedValue(BitWidth);
+        APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
+        Constant *Mask = ConstantInt::get(Ty, Bits);
+        return BinaryOperator::CreateAnd(B, Mask);
       }
+
+      // (((X >> C) & CC) bop Y) << C  ->  (X & (CC << C)) bop (Y << C)
+      if (match(Shr,
+                m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))),
+                               m_APInt(CC))))) {
+        // Y << C
+        Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
+        // X & (CC << C)
+        Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)),
+                                     X->getName() + ".mask");
+        return BinaryOperator::Create(Op0BO->getOpcode(), M, YS);
       }
     }
 

diff --git a/llvm/test/Transforms/InstCombine/shl-bo.ll b/llvm/test/Transforms/InstCombine/shl-bo.ll
index 2ac2211f5ab2..5bdc22671ed7 100644
--- a/llvm/test/Transforms/InstCombine/shl-bo.ll
+++ b/llvm/test/Transforms/InstCombine/shl-bo.ll
@@ -186,7 +186,7 @@ define i8 @lshr_and_add(i8 %a, i8 %y)  {
 ; CHECK-NEXT:    [[X:%.*]] = srem i8 [[A:%.*]], 42
 ; CHECK-NEXT:    [[B1:%.*]] = shl i8 [[X]], 3
 ; CHECK-NEXT:    [[Y_MASK:%.*]] = and i8 [[Y:%.*]], 96
-; CHECK-NEXT:    [[L:%.*]] = add i8 [[B1]], [[Y_MASK]]
+; CHECK-NEXT:    [[L:%.*]] = add i8 [[Y_MASK]], [[B1]]
 ; CHECK-NEXT:    ret i8 [[L]]
 ;
   %x = srem i8 %a, 42 ; thwart complexity-based canonicalization
@@ -267,7 +267,7 @@ define <2 x i8> @lshr_and_and_commute_splat(<2 x i8> %a, <2 x i8> %y)  {
 ; CHECK-NEXT:    [[X:%.*]] = srem <2 x i8> [[A:%.*]], <i8 42, i8 42>
 ; CHECK-NEXT:    [[B1:%.*]] = shl <2 x i8> [[X]], <i8 2, i8 2>
 ; CHECK-NEXT:    [[Y_MASK:%.*]] = and <2 x i8> [[Y:%.*]], <i8 52, i8 52>
-; CHECK-NEXT:    [[L:%.*]] = and <2 x i8> [[B1]], [[Y_MASK]]
+; CHECK-NEXT:    [[L:%.*]] = and <2 x i8> [[Y_MASK]], [[B1]]
 ; CHECK-NEXT:    ret <2 x i8> [[L]]
 ;
   %x = srem <2 x i8> %a, <i8 42, i8 42> ; thwart complexity-based canonicalization
@@ -283,7 +283,7 @@ define i8 @lshr_and_or(i8 %a, i8 %y)  {
 ; CHECK-NEXT:    [[X:%.*]] = srem i8 [[A:%.*]], 42
 ; CHECK-NEXT:    [[B1:%.*]] = shl i8 [[X]], 2
 ; CHECK-NEXT:    [[Y_MASK:%.*]] = and i8 [[Y:%.*]], 52
-; CHECK-NEXT:    [[L:%.*]] = or i8 [[B1]], [[Y_MASK]]
+; CHECK-NEXT:    [[L:%.*]] = or i8 [[Y_MASK]], [[B1]]
 ; CHECK-NEXT:    ret i8 [[L]]
 ;
   %x = srem i8 %a, 42 ; thwart complexity-based canonicalization
@@ -331,7 +331,7 @@ define <2 x i8> @lshr_and_xor_commute_splat(<2 x i8> %a, <2 x i8> %y)  {
 ; CHECK-NEXT:    [[X:%.*]] = srem <2 x i8> [[A:%.*]], <i8 42, i8 42>
 ; CHECK-NEXT:    [[B1:%.*]] = shl <2 x i8> [[X]], <i8 2, i8 2>
 ; CHECK-NEXT:    [[Y_MASK:%.*]] = and <2 x i8> [[Y:%.*]], <i8 52, i8 52>
-; CHECK-NEXT:    [[L:%.*]] = xor <2 x i8> [[B1]], [[Y_MASK]]
+; CHECK-NEXT:    [[L:%.*]] = xor <2 x i8> [[Y_MASK]], [[B1]]
 ; CHECK-NEXT:    ret <2 x i8> [[L]]
 ;
   %x = srem <2 x i8> %a, <i8 42, i8 42> ; thwart complexity-based canonicalization


        

