[llvm-commits] [llvm] r57095 - in /llvm/trunk: lib/Transforms/Scalar/InstructionCombining.cpp test/Transforms/InstCombine/bswap.ll

Chris Lattner sabre at nondot.org
Sat Oct 4 19:13:19 PDT 2008


Author: lattner
Date: Sat Oct  4 21:13:19 2008
New Revision: 57095

URL: http://llvm.org/viewvc/llvm-project?rev=57095&view=rev
Log:
rewrite bswap matching to be more general, allowing arbitrary
shifting and masking inside a bswap expr.  This allows it to handle
the cases from PR2842, which involve the intermediate 'or' 
expressions being shifted, not just the input value.


Modified:
    llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp
    llvm/trunk/test/Transforms/InstCombine/bswap.ll

Modified: llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp?rev=57095&r1=57094&r2=57095&view=diff

==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp Sat Oct  4 21:13:19 2008
@@ -3887,88 +3887,130 @@
   return Changed ? &I : 0;
 }
 
-/// CollectBSwapParts - Look to see if the specified value defines a single byte
-/// in the result.  If it does, and if the specified byte hasn't been filled in
-/// yet, fill it in and return false.
-static bool CollectBSwapParts(Value *V, SmallVector<Value*, 8> &ByteValues) {
-  Instruction *I = dyn_cast<Instruction>(V);
-  if (I == 0) return true;
+/// CollectBSwapParts - Analyze the specified subexpression and see if it is
+/// capable of providing pieces of a bswap.  The subexpression provides pieces
+/// of a bswap if it is proven that each of the non-zero bytes in the output of
+/// the expression came from the corresponding "byte swapped" byte in some other
+/// value.  For example, if the current subexpression is "(shl i32 %X, 24)" then
+/// we know that the expression deposits the low byte of %X into the high byte
+/// of the bswap result and that all other bytes are zero.  This expression is
+/// accepted, and the high byte of ByteValues is set to X to indicate a correct
+/// match.
+///
+/// This function returns true if the match fails and false if it succeeds.
+/// On entry to the function the "OverallLeftShift" is a signed integer value
+/// indicating the number of bytes that the subexpression is later shifted.  For
+/// example, if the expression is later right shifted by 16 bits, the
+/// OverallLeftShift value would be -2 on entry.  This is used to specify which
+/// byte of ByteValues is actually being set.
+///
+/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding
+/// byte is masked to zero by a user.  For example, in (X & 255), X will be
+/// processed with a bytemask of 1.  Because bytemask is 32-bits, this limits
+/// this function to working on up to 32-byte (256 bit) values.  ByteMask is
+/// always in the local (OverallLeftShift) coordinate space.
+///
+static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask,
+                              SmallVector<Value*, 8> &ByteValues) {
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    // If this is an or instruction, it may be an inner node of the bswap.
+    if (I->getOpcode() == Instruction::Or) {
+      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
+                               ByteValues) ||
+             CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask,
+                               ByteValues);
+    }
+  
+    // If this is a logical shift by a constant multiple of 8, recurse with
+    // OverallLeftShift and ByteMask adjusted.
+    if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
+      unsigned ShAmt = 
+        cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
+      // Ensure the shift amount is defined and of a byte value.
+      if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size()))
+        return true;
 
-  // If this is an or instruction, it is an inner node of the bswap.
-  if (I->getOpcode() == Instruction::Or)
-    return CollectBSwapParts(I->getOperand(0), ByteValues) ||
-           CollectBSwapParts(I->getOperand(1), ByteValues);
-  
-  uint32_t BitWidth = I->getType()->getPrimitiveSizeInBits();
-  // If this is a shift by a constant int, and it is "24", then its operand
-  // defines a byte.  We only handle unsigned types here.
-  if (I->isShift() && isa<ConstantInt>(I->getOperand(1))) {
-    // Not shifting the entire input by N-1 bytes?
-    if (cast<ConstantInt>(I->getOperand(1))->getLimitedValue(BitWidth) !=
-        8*(ByteValues.size()-1))
-      return true;
-    
-    unsigned DestNo;
-    if (I->getOpcode() == Instruction::Shl) {
-      // X << 24 defines the top byte with the lowest of the input bytes.
-      DestNo = ByteValues.size()-1;
-    } else if (I->getOpcode() == Instruction::LShr) {
-      // X >>u 24 defines the low byte with the highest of the input bytes.
-      DestNo = 0;
-    } else {
-      // Arithmetic shift right may have the top bits set.
+      unsigned ByteShift = ShAmt >> 3;
+      if (I->getOpcode() == Instruction::Shl) {
+        // X << 2 -> collect(X, +2)
+        OverallLeftShift += ByteShift;
+        ByteMask >>= ByteShift;
+      } else {
+        // X >>u 2 -> collect(X, -2)
+        OverallLeftShift -= ByteShift;
+        ByteMask <<= ByteShift;
+        ByteMask &= (~0U >> 32-ByteValues.size());
+      }
+
+      if (OverallLeftShift >= (int)ByteValues.size()) return true;
+      if (OverallLeftShift <= -(int)ByteValues.size()) return true;
+
+      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, 
+                               ByteValues);
+    }
+
+    // If this is a logical 'and' with a mask that clears bytes, clear the
+    // corresponding bytes in ByteMask.
+    if (I->getOpcode() == Instruction::And &&
+        isa<ConstantInt>(I->getOperand(1))) {
+      // Scan every byte of the and mask, seeing if the byte is either 0 or 255.
+      unsigned NumBytes = ByteValues.size();
+      APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255);
+      const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
+      
+      for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) {
+        // If this byte is masked out by a later operation, we don't care what
+        // the and mask is.
+        if ((ByteMask & (1 << i)) == 0)
+          continue;
+        
+        // If the AndMask is all zeros for this byte, clear the bit.
+        APInt MaskB = AndMask & Byte;
+        if (MaskB == 0) {
+          ByteMask &= ~(1U << i);
+          continue;
+        }
+        
+        // If the AndMask is not all ones for this byte, it's not a bytezap.
+        if (MaskB != Byte)
+          return true;
+
+        // Otherwise, this byte is kept.
+      }
+
+      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask, 
+                               ByteValues);
+    }
+  }
+  
+  // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
+  // the input value to the bswap.  Some observations: 1) if more than one byte
+  // is demanded from this input, then it could not be successfully assembled
+  // into a byteswap.  At least one of the two bytes would not be aligned with
+  // their ultimate destination.
+  if (!isPowerOf2_32(ByteMask)) return true;
+  unsigned InputByteNo = CountTrailingZeros_32(ByteMask);
+  
+  // 2) The input and ultimate destinations must line up: if byte 3 of an i32
+  // is demanded, it needs to go into byte 0 of the result.  This means that the
+  // byte needs to be shifted until it lands in the right byte bucket.  The
+  // shift amount depends on the position: if the byte is coming from the high
+  // part of the value (e.g. byte 3) then it must be shifted right.  If from the
+  // low part, it must be shifted left.
+  unsigned DestByteNo = InputByteNo + OverallLeftShift;
+  if (InputByteNo < ByteValues.size()/2) {
+    if (ByteValues.size()-1-DestByteNo != InputByteNo)
       return true;
-    }
-    
-    // If the destination byte value is already defined, the values are or'd
-    // together, which isn't a bswap (unless it's an or of the same bits).
-    if (ByteValues[DestNo] && ByteValues[DestNo] != I->getOperand(0))
+  } else {
+    if (ByteValues.size()-1-DestByteNo != InputByteNo)
       return true;
-    ByteValues[DestNo] = I->getOperand(0);
-    return false;
   }
   
-  // Otherwise, we can only handle and(shift X, imm), imm).  Bail out of if we
-  // don't have this.
-  Value *Shift = 0, *ShiftLHS = 0;
-  ConstantInt *AndAmt = 0, *ShiftAmt = 0;
-  if (!match(I, m_And(m_Value(Shift), m_ConstantInt(AndAmt))) ||
-      !match(Shift, m_Shift(m_Value(ShiftLHS), m_ConstantInt(ShiftAmt))))
-    return true;
-  Instruction *SI = cast<Instruction>(Shift);
-
-  // Make sure that the shift amount is by a multiple of 8 and isn't too big.
-  if (ShiftAmt->getLimitedValue(BitWidth) & 7 ||
-      ShiftAmt->getLimitedValue(BitWidth) > 8*ByteValues.size())
-    return true;
-  
-  // Turn 0xFF -> 0, 0xFF00 -> 1, 0xFF0000 -> 2, etc.
-  unsigned DestByte;
-  if (AndAmt->getValue().getActiveBits() > 64)
-    return true;
-  uint64_t AndAmtVal = AndAmt->getZExtValue();
-  for (DestByte = 0; DestByte != ByteValues.size(); ++DestByte)
-    if (AndAmtVal == uint64_t(0xFF) << 8*DestByte)
-      break;
-  // Unknown mask for bswap.
-  if (DestByte == ByteValues.size()) return true;
-  
-  unsigned ShiftBytes = ShiftAmt->getZExtValue()/8;
-  unsigned SrcByte;
-  if (SI->getOpcode() == Instruction::Shl)
-    SrcByte = DestByte - ShiftBytes;
-  else
-    SrcByte = DestByte + ShiftBytes;
-  
-  // If the SrcByte isn't a bswapped value from the DestByte, reject it.
-  if (SrcByte != ByteValues.size()-DestByte-1)
-    return true;
-  
   // If the destination byte value is already defined, the values are or'd
   // together, which isn't a bswap (unless it's an or of the same bits).
-  if (ByteValues[DestByte] && ByteValues[DestByte] != SI->getOperand(0))
+  if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V)
     return true;
-  ByteValues[DestByte] = SI->getOperand(0);
+  ByteValues[DestByteNo] = V;
   return false;
 }
 
@@ -3976,7 +4018,9 @@
 /// If so, insert the new bswap intrinsic and return it.
 Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) {
   const IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
-  if (!ITy || ITy->getBitWidth() % 16) 
+  if (!ITy || ITy->getBitWidth() % 16 || 
+      // ByteMask only allows up to 32-byte values.
+      ITy->getBitWidth() > 32*8) 
     return 0;   // Can only bswap pairs of bytes.  Can't do vectors.
   
   /// ByteValues - For each byte of the result, we keep track of which value
@@ -3985,8 +4029,8 @@
   ByteValues.resize(ITy->getBitWidth()/8);
     
   // Try to find all the pieces corresponding to the bswap.
-  if (CollectBSwapParts(I.getOperand(0), ByteValues) ||
-      CollectBSwapParts(I.getOperand(1), ByteValues))
+  uint32_t ByteMask = ~0U >> (32-ByteValues.size());
+  if (CollectBSwapParts(&I, 0, ByteMask, ByteValues))
     return 0;
   
   // Check to see if all of the bytes come from the same value.

Modified: llvm/trunk/test/Transforms/InstCombine/bswap.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/bswap.ll?rev=57095&r1=57094&r2=57095&view=diff

==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/bswap.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/bswap.ll Sat Oct  4 21:13:19 2008
@@ -1,5 +1,5 @@
 ; RUN: llvm-as < %s | opt -instcombine | llvm-dis | \
-; RUN:    grep {call.*llvm.bswap} | count 5
+; RUN:    grep {call.*llvm.bswap} | count 6
 
 define i32 @test1(i32 %i) {
 	%tmp1 = lshr i32 %i, 24		; <i32> [#uses=1]
@@ -55,3 +55,18 @@
 	%retval = trunc i32 %tmp6.upgrd.4 to i16		; <i16> [#uses=1]
 	ret i16 %retval
 }
+
+; PR2842
+define i32 @test6(i32 %x) nounwind readnone {
+	%tmp = shl i32 %x, 16		; <i32> [#uses=1]
+	%x.mask = and i32 %x, 65280		; <i32> [#uses=1]
+	%tmp1 = lshr i32 %x, 16		; <i32> [#uses=1]
+	%tmp2 = and i32 %tmp1, 255		; <i32> [#uses=1]
+	%tmp3 = or i32 %x.mask, %tmp		; <i32> [#uses=1]
+	%tmp4 = or i32 %tmp3, %tmp2		; <i32> [#uses=1]
+	%tmp5 = shl i32 %tmp4, 8		; <i32> [#uses=1]
+	%tmp6 = lshr i32 %x, 24		; <i32> [#uses=1]
+	%tmp7 = or i32 %tmp5, %tmp6		; <i32> [#uses=1]
+	ret i32 %tmp7
+}
+





More information about the llvm-commits mailing list