[llvm] 3f3356b - [InstCombine] allow vector splats for add+xor --> shifts

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 11 06:09:42 PDT 2020


Author: Sanjay Patel
Date: 2020-10-11T09:04:24-04:00
New Revision: 3f3356bdd9c7188530f6582b4a407469131ae679

URL: https://github.com/llvm/llvm-project/commit/3f3356bdd9c7188530f6582b4a407469131ae679
DIFF: https://github.com/llvm/llvm-project/commit/3f3356bdd9c7188530f6582b4a407469131ae679.diff

LOG: [InstCombine] allow vector splats for add+xor --> shifts

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/test/Transforms/InstCombine/signext.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 2dd5180378a1..4987ba7091ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -936,6 +936,25 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
       if ((*C2 | LHSKnown.Zero).isAllOnesValue())
         return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X);
     }
+
+    // Look for a math+logic pattern that corresponds to sext-in-register of a
+    // value with cleared high bits. Convert that into a pair of shifts:
+    // add (xor X, 0x80), 0xF..F80 --> (X << ShAmtC) >>s ShAmtC
+    // add (xor X, 0xF..F80), 0x80 --> (X << ShAmtC) >>s ShAmtC
+    if (Op0->hasOneUse() && *C2 == -(*C)) {
+      unsigned BitWidth = Ty->getScalarSizeInBits();
+      unsigned ShAmt = 0;
+      if (C->isPowerOf2())
+        ShAmt = BitWidth - C->logBase2() - 1;
+      else if (C2->isPowerOf2())
+        ShAmt = BitWidth - C2->logBase2() - 1;
+      if (ShAmt && MaskedValueIsZero(X, APInt::getHighBitsSet(BitWidth, ShAmt),
+                                     0, &Add)) {
+        Constant *ShAmtC = ConstantInt::get(Ty, ShAmt);
+        Value *NewShl = Builder.CreateShl(X, ShAmtC, "sext");
+        return BinaryOperator::CreateAShr(NewShl, ShAmtC);
+      }
+    }
   }
 
   if (C->isOneValue() && Op0->hasOneUse()) {
@@ -1284,39 +1303,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *X = foldNoWrapAdd(I, Builder))
     return X;
 
-  // FIXME: This should be moved into the above helper function to allow these
-  // transforms for general constant or constant splat vectors.
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Type *Ty = I.getType();
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-    Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr;
-    if (match(LHS, m_OneUse(m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS))))) {
-      unsigned TySizeBits = Ty->getScalarSizeInBits();
-      const APInt &RHSVal = CI->getValue();
-      unsigned ExtendAmt = 0;
-      // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
-      // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
-      if (XorRHS->getValue() == -RHSVal) {
-        if (RHSVal.isPowerOf2())
-          ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
-        else if (XorRHS->getValue().isPowerOf2())
-          ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
-      }
-
-      if (ExtendAmt) {
-        APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
-        if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
-          ExtendAmt = 0;
-      }
-
-      if (ExtendAmt) {
-        Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt);
-        Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext");
-        return BinaryOperator::CreateAShr(NewShl, ShAmt);
-      }
-    }
-  }
-
   if (Ty->isIntOrIntVectorTy(1))
     return BinaryOperator::CreateXor(LHS, RHS);
 

diff --git a/llvm/test/Transforms/InstCombine/signext.ll b/llvm/test/Transforms/InstCombine/signext.ll
index 4faf4e384874..447e54849df6 100644
--- a/llvm/test/Transforms/InstCombine/signext.ll
+++ b/llvm/test/Transforms/InstCombine/signext.ll
@@ -34,9 +34,8 @@ define i32 @sextinreg_extra_use(i32 %x) {
 
 define <2 x i32> @sextinreg_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]], <i32 65535, i32 65535>
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]], <i32 -32768, i32 -32768>
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]], <i32 32768, i32 32768>
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]], <i32 16, i32 16>
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]], <i32 16, i32 16>
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x, <i32 65535, i32 65535>
@@ -59,9 +58,8 @@ define i32 @sextinreg_alt(i32 %x) {
 
 define <2 x i32> @sextinreg_alt_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg_alt_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]], <i32 65535, i32 65535>
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]], <i32 32768, i32 32768>
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]], <i32 -32768, i32 -32768>
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]], <i32 16, i32 16>
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]], <i32 16, i32 16>
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x, <i32 65535, i32 65535>
@@ -121,9 +119,8 @@ define i32 @sextinreg2(i32 %x) {
 
 define <2 x i32> @sextinreg2_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg2_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]], <i32 255, i32 255>
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]], <i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]], <i32 -128, i32 -128>
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]], <i32 24, i32 24>
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]], <i32 24, i32 24>
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x, <i32 255, i32 255>
@@ -184,9 +181,7 @@ define i32 @ashr(i32 %x) {
 
 define <2 x i32> @ashr_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @ashr_splat(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 5, i32 5>
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i32> [[SHR]], <i32 67108864, i32 67108864>
-; CHECK-NEXT:    [[SUB:%.*]] = add nsw <2 x i32> [[XOR]], <i32 -67108864, i32 -67108864>
+; CHECK-NEXT:    [[SUB:%.*]] = ashr <2 x i32> [[X:%.*]], <i32 5, i32 5>
 ; CHECK-NEXT:    ret <2 x i32> [[SUB]]
 ;
   %shr = lshr <2 x i32> %x, <i32 5, i32 5>


        


More information about the llvm-commits mailing list