[llvm] a9d6b0e - [InstCombine] Fix big-endian miscompile of (bitcast (zext/trunc (bitcast)))

Bjorn Pettersson via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 2 02:06:25 PST 2019


Author: Bjorn Pettersson
Date: 2019-12-02T11:05:25+01:00
New Revision: a9d6b0e5444741d08ff1df7cf71d1559e7fefc1f

URL: https://github.com/llvm/llvm-project/commit/a9d6b0e5444741d08ff1df7cf71d1559e7fefc1f
DIFF: https://github.com/llvm/llvm-project/commit/a9d6b0e5444741d08ff1df7cf71d1559e7fefc1f.diff

LOG: [InstCombine] Fix big-endian miscompile of (bitcast (zext/trunc (bitcast)))

Summary:
optimizeVectorResize is rewriting patterns like:
  %1 = bitcast vector %src to integer
  %2 = trunc/zext %1
  %dst = bitcast %2 to vector

Since bitcasting between integer and vector types gives
different integer values depending on endianness, we need
to take endianness into account. As it happens the old
implementation only produced the correct result for little
endian targets.

Fixes: https://bugs.llvm.org/show_bug.cgi?id=44178

Reviewers: spatel, lattner, lebedev.ri

Reviewed By: spatel, lebedev.ri

Subscribers: lebedev.ri, hiraditya, uabelho, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D70844

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/cast.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 0390368c4bb4..078a80de2df4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -18,6 +18,7 @@
 #include "llvm/IR/DIBuilder.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/KnownBits.h"
+#include <numeric>
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -1820,12 +1821,24 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
 }
 
 /// This input value (which is known to have vector type) is being zero extended
-/// or truncated to the specified vector type.
+/// or truncated to the specified vector type. Since the zext/trunc is done
+/// using an integer type, we have a (bitcast(cast(bitcast))) pattern,
+/// endianness will impact which end of the vector that is extended or
+/// truncated.
+///
+/// A vector is always stored with index 0 at the lowest address, which
+/// corresponds to the most significant bits for a big endian stored integer and
+/// the least significant bits for little endian. A trunc/zext of an integer
+/// impacts the big end of the integer. Thus, we need to add/remove elements at
+/// the front of the vector for big endian targets, and the back of the vector
+/// for little endian targets.
+///
 /// Try to replace it with a shuffle (and vector/vector bitcast) if possible.
 ///
 /// The source and destination vector types may have different element types.
-static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy,
-                                         InstCombiner &IC) {
+static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal,
+                                                            VectorType *DestTy,
+                                                            InstCombiner &IC) {
   // We can only do this optimization if the output is a multiple of the input
   // element size, or the input is a multiple of the output element size.
   // Convert the input type to have the same element type as the output.
@@ -1844,31 +1857,53 @@ static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy,
     InVal = IC.Builder.CreateBitCast(InVal, SrcTy);
   }
 
+  bool IsBigEndian = IC.getDataLayout().isBigEndian();
+  unsigned SrcElts = SrcTy->getNumElements();
+  unsigned DestElts = DestTy->getNumElements();
+
+  assert(SrcElts != DestElts && "Element counts should be different.");
+
   // Now that the element types match, get the shuffle mask and RHS of the
   // shuffle to use, which depends on whether we're increasing or decreasing the
   // size of the input.
-  SmallVector<uint32_t, 16> ShuffleMask;
+  SmallVector<uint32_t, 16> ShuffleMaskStorage;
+  ArrayRef<uint32_t> ShuffleMask;
   Value *V2;
 
-  if (SrcTy->getNumElements() > DestTy->getNumElements()) {
-    // If we're shrinking the number of elements, just shuffle in the low
-    // elements from the input and use undef as the second shuffle input.
-    V2 = UndefValue::get(SrcTy);
-    for (unsigned i = 0, e = DestTy->getNumElements(); i != e; ++i)
-      ShuffleMask.push_back(i);
+  // Produce an identity shuffle mask for the src vector.
+  ShuffleMaskStorage.resize(SrcElts);
+  std::iota(ShuffleMaskStorage.begin(), ShuffleMaskStorage.end(), 0);
 
+  if (SrcElts > DestElts) {
+    // If we're shrinking the number of elements (rewriting an integer
+    // truncate), just shuffle in the elements corresponding to the least
+    // significant bits from the input and use undef as the second shuffle
+    // input.
+    V2 = UndefValue::get(SrcTy);
+    // Make sure the shuffle mask selects the "least significant bits" by
+    // keeping elements from back of the src vector for big endian, and from the
+    // front for little endian.
+    ShuffleMask = ShuffleMaskStorage;
+    if (IsBigEndian)
+      ShuffleMask = ShuffleMask.take_back(DestElts);
+    else
+      ShuffleMask = ShuffleMask.take_front(DestElts);
   } else {
-    // If we're increasing the number of elements, shuffle in all of the
-    // elements from InVal and fill the rest of the result elements with zeros
-    // from a constant zero.
+    // If we're increasing the number of elements (rewriting an integer zext),
+    // shuffle in all of the elements from InVal. Fill the rest of the result
+    // elements with zeros from a constant zero.
     V2 = Constant::getNullValue(SrcTy);
-    unsigned SrcElts = SrcTy->getNumElements();
-    for (unsigned i = 0, e = SrcElts; i != e; ++i)
-      ShuffleMask.push_back(i);
-
-    // The excess elements reference the first element of the zero input.
-    for (unsigned i = 0, e = DestTy->getNumElements()-SrcElts; i != e; ++i)
-      ShuffleMask.push_back(SrcElts);
+    // Use first elt from V2 when indicating zero in the shuffle mask.
+    uint32_t NullElt = SrcElts;
+    // Extend with null values in the "most significant bits" by adding elements
+    // in front of the src vector for big endian, and at the back for little
+    // endian.
+    unsigned DeltaElts = DestElts - SrcElts;
+    if (IsBigEndian)
+      ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt);
+    else
+      ShuffleMaskStorage.append(DeltaElts, NullElt);
+    ShuffleMask = ShuffleMaskStorage;
   }
 
   return new ShuffleVectorInst(InVal, V2,
@@ -2375,8 +2410,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
         CastInst *SrcCast = cast<CastInst>(Src);
         if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0)))
           if (isa<VectorType>(BCIn->getOperand(0)->getType()))
-            if (Instruction *I = optimizeVectorResize(BCIn->getOperand(0),
-                                               cast<VectorType>(DestTy), *this))
+            if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts(
+                    BCIn->getOperand(0), cast<VectorType>(DestTy), *this))
               return I;
       }
 

diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 66eb3904ebb7..d85286b46029 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -823,9 +823,13 @@ define i64 @test59(i8 %A, i8 %B) {
 }
 
 define <3 x i32> @test60(<4 x i32> %call4) {
-; ALL-LABEL: @test60(
-; ALL-NEXT:    [[P10:%.*]] = shufflevector <4 x i32> [[CALL4:%.*]], <4 x i32> undef, <3 x i32> <i32 0, i32 1, i32 2>
-; ALL-NEXT:    ret <3 x i32> [[P10]]
+; BE-LABEL: @test60(
+; BE-NEXT:    [[P10:%.*]] = shufflevector <4 x i32> [[CALL4:%.*]], <4 x i32> undef, <3 x i32> <i32 1, i32 2, i32 3>
+; BE-NEXT:    ret <3 x i32> [[P10]]
+;
+; LE-LABEL: @test60(
+; LE-NEXT:    [[P10:%.*]] = shufflevector <4 x i32> [[CALL4:%.*]], <4 x i32> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; LE-NEXT:    ret <3 x i32> [[P10]]
 ;
   %p11 = bitcast <4 x i32> %call4 to i128
   %p9 = trunc i128 %p11 to i96
@@ -835,9 +839,13 @@ define <3 x i32> @test60(<4 x i32> %call4) {
 }
 
 define <4 x i32> @test61(<3 x i32> %call4) {
-; ALL-LABEL: @test61(
-; ALL-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[CALL4:%.*]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; ALL-NEXT:    ret <4 x i32> [[P10]]
+; BE-LABEL: @test61(
+; BE-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[CALL4:%.*]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 3, i32 0, i32 1, i32 2>
+; BE-NEXT:    ret <4 x i32> [[P10]]
+;
+; LE-LABEL: @test61(
+; LE-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[CALL4:%.*]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; LE-NEXT:    ret <4 x i32> [[P10]]
 ;
   %p11 = bitcast <3 x i32> %call4 to i96
   %p9 = zext i96 %p11 to i128
@@ -846,10 +854,15 @@ define <4 x i32> @test61(<3 x i32> %call4) {
 }
 
 define <4 x i32> @test62(<3 x float> %call4) {
-; ALL-LABEL: @test62(
-; ALL-NEXT:    [[TMP1:%.*]] = bitcast <3 x float> [[CALL4:%.*]] to <3 x i32>
-; ALL-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[TMP1]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; ALL-NEXT:    ret <4 x i32> [[P10]]
+; BE-LABEL: @test62(
+; BE-NEXT:    [[TMP1:%.*]] = bitcast <3 x float> [[CALL4:%.*]] to <3 x i32>
+; BE-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[TMP1]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 3, i32 0, i32 1, i32 2>
+; BE-NEXT:    ret <4 x i32> [[P10]]
+;
+; LE-LABEL: @test62(
+; LE-NEXT:    [[TMP1:%.*]] = bitcast <3 x float> [[CALL4:%.*]] to <3 x i32>
+; LE-NEXT:    [[P10:%.*]] = shufflevector <3 x i32> [[TMP1]], <3 x i32> <i32 0, i32 undef, i32 undef>, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; LE-NEXT:    ret <4 x i32> [[P10]]
 ;
   %p11 = bitcast <3 x float> %call4 to i96
   %p9 = zext i96 %p11 to i128


        


More information about the llvm-commits mailing list