[llvm] 5c91aa6 - [InstCombine] Fold or(zext(bswap(x)),shl(zext(bswap(y)),bw/2)) -> bswap(or(zext(y),shl(zext(x),bw/2)))

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue May 5 04:33:01 PDT 2020


Author: Simon Pilgrim
Date: 2020-05-05T12:30:10+01:00
New Revision: 5c91aa660386ea22e0d38eb0de4c26d62073ccb3

URL: https://github.com/llvm/llvm-project/commit/5c91aa660386ea22e0d38eb0de4c26d62073ccb3
DIFF: https://github.com/llvm/llvm-project/commit/5c91aa660386ea22e0d38eb0de4c26d62073ccb3.diff

LOG: [InstCombine] Fold or(zext(bswap(x)),shl(zext(bswap(y)),bw/2)) -> bswap(or(zext(y),shl(zext(x),bw/2)))

This adds a general combine that can be used to fold:

  or(zext(OP(x)), shl(zext(OP(y)),bw/2))
-->
  OP(or(zext(y), shl(zext(x),bw/2)))

(x and y trade halves because OP reverses the order of the whole value.)

This allows us to widen 'concat-able' or+zext patterns. I've only set this up for BSWAP, but it could be used for other similar ops (BITREVERSE, for instance).

We already do something similar for bitop(bswap(x),bswap(y)) --> bswap(bitop(x,y))

Fixes PR45715

Reviewed By: @lebedev.ri

Differential Revision: https://reviews.llvm.org/D79041

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/test/Transforms/InstCombine/or-concat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc6dcdd748a..a4d86d751c2f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2132,6 +2132,49 @@ static Instruction *matchRotate(Instruction &Or) {
   return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt });
 }
 
+/// Fold or(zext(bswap(a)),shl(zext(bswap(b)),bw/2)) into one wide bswap.
+static Instruction *matchOrConcat(Instruction &Or,
+                                  InstCombiner::BuilderTy &Builder) {
+  assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'");
+  Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1);
+  Type *Ty = Or.getType();
+
+  unsigned Width = Ty->getScalarSizeInBits();
+  if ((Width & 1) != 0)
+    return nullptr;
+  unsigned HalfWidth = Width / 2;
+
+  // Canonicalize the zext (lower half) operand to Op0; fails below if none.
+  if (!isa<ZExtInst>(Op0))
+    std::swap(Op0, Op1);
+
+  // Match lower half zext(LowerSrc) and upper half shl(zext(UpperSrc), C).
+  Value *LowerSrc, *ShlVal, *UpperSrc;
+  const APInt *C;
+  if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) ||
+      !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) ||
+      !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc)))))
+    return nullptr;
+  if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() ||
+      LowerSrc->getType()->getScalarSizeInBits() != HalfWidth)
+    return nullptr;
+
+  // Both halves must be bswaps; LowerBSwap/UpperBSwap are their operands.
+  // TODO: Add more patterns (bitreverse?)
+  Value *LowerBSwap, *UpperBSwap;
+  if (!match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) ||
+      !match(UpperSrc, m_BSwap(m_Value(UpperBSwap))))
+    return nullptr;
+
+  // Push the concat down: bswap(hi:lo) == bswap(lo):bswap(hi), so swap halves.
+  Value *NewLower = Builder.CreateZExt(UpperBSwap, Ty);
+  Value *NewUpper = Builder.CreateZExt(LowerBSwap, Ty);
+  NewUpper = Builder.CreateShl(NewUpper, HalfWidth);
+  Value *BinOp = Builder.CreateOr(NewLower, NewUpper);
+  Function *F = Intrinsic::getDeclaration(Or.getModule(), Intrinsic::bswap, Ty);
+  return Builder.CreateCall(F, BinOp);
+}
+
 /// If all elements of two constant vectors are 0/-1 and inverses, return true.
 static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
   unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements();
@@ -2532,6 +2575,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
   if (Instruction *Rotate = matchRotate(I))
     return Rotate;
 
+  if (Instruction *Concat = matchOrConcat(I, Builder))
+    return replaceInstUsesWith(I, Concat);
+
   Value *X, *Y;
   const APInt *CV;
   if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&

diff  --git a/llvm/test/Transforms/InstCombine/or-concat.ll b/llvm/test/Transforms/InstCombine/or-concat.ll
index f0d36f2a60e5..77cdaa9a37dd 100644
--- a/llvm/test/Transforms/InstCombine/or-concat.ll
+++ b/llvm/test/Transforms/InstCombine/or-concat.ll
@@ -13,16 +13,8 @@
 ; PR45715
 define i64 @concat_bswap32_unary_split(i64 %a0) {
 ; CHECK-LABEL: @concat_bswap32_unary_split(
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[A0]] to i32
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP2]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP3]])
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
-; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = shl nuw i64 [[TMP7]], 32
-; CHECK-NEXT:    [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP6]]
-; CHECK-NEXT:    ret i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.bswap.i64(i64 [[A0:%.*]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
 ;
   %1 = lshr i64 %a0, 32
   %2 = trunc i64 %1 to i32
@@ -39,15 +31,10 @@ define i64 @concat_bswap32_unary_split(i64 %a0) {
 define i64 @concat_bswap32_unary_flip(i64 %a0) {
 ; CHECK-LABEL: @concat_bswap32_unary_flip(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A0:%.*]], 32
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[A0]] to i32
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP2]])
-; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[TMP3]])
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
-; CHECK-NEXT:    [[TMP7:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP8:%.*]] = shl nuw i64 [[TMP6]], 32
-; CHECK-NEXT:    [[TMP9:%.*]] = or i64 [[TMP8]], [[TMP7]]
-; CHECK-NEXT:    ret i64 [[TMP9]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[A0]], 32
+; CHECK-NEXT:    [[TMP3:%.*]] = or i64 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.bswap.i64(i64 [[TMP3]])
+; CHECK-NEXT:    ret i64 [[TMP4]]
 ;
   %1 = lshr i64 %a0, 32
   %2 = trunc i64 %1 to i32
@@ -63,13 +50,12 @@ define i64 @concat_bswap32_unary_flip(i64 %a0) {
 
 define i64 @concat_bswap32_binary(i32 %a0, i32 %a1) {
 ; CHECK-LABEL: @concat_bswap32_binary(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[A0:%.*]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call i32 @llvm.bswap.i32(i32 [[A1:%.*]])
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i32 [[TMP1]] to i64
-; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[TMP2]] to i64
-; CHECK-NEXT:    [[TMP5:%.*]] = shl nuw i64 [[TMP4]], 32
-; CHECK-NEXT:    [[TMP6:%.*]] = or i64 [[TMP5]], [[TMP3]]
-; CHECK-NEXT:    ret i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[A1:%.*]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[A0:%.*]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 32
+; CHECK-NEXT:    [[TMP4:%.*]] = or i64 [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call i64 @llvm.bswap.i64(i64 [[TMP4]])
+; CHECK-NEXT:    ret i64 [[TMP5]]
 ;
   %1 = tail call i32 @llvm.bswap.i32(i32 %a0)
   %2 = tail call i32 @llvm.bswap.i32(i32 %a1)


        


More information about the llvm-commits mailing list