[llvm] 9c3138b - [InstCombine] visitTrunc - pass through undefs for trunc(shift(trunc/ext(x),c)) patterns

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 13 07:00:22 PDT 2020


Author: Simon Pilgrim
Date: 2020-10-13T14:35:18+01:00
New Revision: 9c3138bd6d8b3e303f0f711753506b330ffa8df0

URL: https://github.com/llvm/llvm-project/commit/9c3138bd6d8b3e303f0f711753506b330ffa8df0
DIFF: https://github.com/llvm/llvm-project/commit/9c3138bd6d8b3e303f0f711753506b330ffa8df0.diff

LOG: [InstCombine] visitTrunc - pass through undefs for trunc(shift(trunc/ext(x),c)) patterns

This is based on the recent patches D88475 and D88429, where we are losing undef values due to extension/comparisons.

I've added a Constant::mergeUndefsWith method that merges the undef scalar/elements from another Constant into a specific Constant.

Differential Revision: https://reviews.llvm.org/D88687

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constant.h
    llvm/lib/IR/Constants.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/cast.ll
    llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index f4cdef2af774..97650c2051ca 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -204,6 +204,12 @@ class Constant : public User {
   /// Try to replace undefined constant C or undefined elements in C with
   /// Replacement. If no changes are made, the constant C is returned.
   static Constant *replaceUndefsWith(Constant *C, Constant *Replacement);
+
+  /// Merges undefs of a Constant with another Constant, along with the
+  /// undefs already present. Other doesn't have to be the same type as C, but
+  /// both must either be scalars or vectors with the same element count. If no
+  /// changes are made, the constant C is returned.
+  static Constant *mergeUndefsWith(Constant *C, Constant *Other);
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 9f83861f2aa6..7eca7dddf4a4 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -737,6 +737,40 @@ Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) {
   return ConstantVector::get(NewC);
 }
 
+Constant *Constant::mergeUndefsWith(Constant *C, Constant *Other) {
+  assert(C && Other && "Expected non-nullptr constant arguments");
+  if (match(C, m_Undef()))
+    return C;
+
+  Type *Ty = C->getType();
+  if (match(Other, m_Undef()))
+    return UndefValue::get(Ty);
+
+  auto *VTy = dyn_cast<FixedVectorType>(Ty);
+  if (!VTy)
+    return C;
+
+  Type *EltTy = VTy->getElementType();
+  unsigned NumElts = VTy->getNumElements();
+  assert(isa<FixedVectorType>(Other->getType()) &&
+         cast<FixedVectorType>(Other->getType())->getNumElements() == NumElts &&
+         "Type mismatch");
+
+  bool FoundExtraUndef = false;
+  SmallVector<Constant *, 32> NewC(NumElts);
+  for (unsigned I = 0; I != NumElts; ++I) {
+    NewC[I] = C->getAggregateElement(I);
+    Constant *OtherEltC = Other->getAggregateElement(I);
+    assert(NewC[I] && OtherEltC && "Unknown vector element");
+    if (!match(NewC[I], m_Undef()) && match(OtherEltC, m_Undef())) {
+      NewC[I] = UndefValue::get(EltTy);
+      FoundExtraUndef = true;
+    }
+  }
+  if (FoundExtraUndef)
+    return ConstantVector::get(NewC);
+  return C;
+}
 
 //===----------------------------------------------------------------------===//
 //                                ConstantInt

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e259b898351d..478032f56bdf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -810,8 +810,6 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
 
     // If the shift is small enough, all zero bits created by the shift are
     // removed by the trunc.
-    // TODO: Support passing through undef shift amounts - these currently get
-    // clamped to MaxAmt.
     if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                     APInt(SrcWidth, MaxShiftAmt)))) {
       // trunc (lshr (sext A), C) --> ashr A, C
@@ -819,6 +817,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
         Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
         Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
         ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+        ShAmt = Constant::mergeUndefsWith(ShAmt, C);
         return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
                        : BinaryOperator::CreateAShr(A, ShAmt);
       }
@@ -841,13 +840,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
 
     // If the shift is small enough, all zero/sign bits created by the shift are
     // removed by the trunc.
-    // TODO: Support passing through undef shift amounts - these currently get
-    // zero'd by getIntegerCast.
     if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                     APInt(SrcWidth, MaxShiftAmt)))) {
       auto *OldShift = cast<Instruction>(Src);
-      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
       bool IsExact = OldShift->isExact();
+      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+      ShAmt = Constant::mergeUndefsWith(ShAmt, C);
       Value *Shift =
           OldShift->getOpcode() == Instruction::AShr
               ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)

diff  --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index c5f18b4c625e..c7217b9f4dd3 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -1570,7 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_uniform(<2 x i8> %A) {
 
 define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_uniform_undef(
-; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 7>
+; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 undef>
 ; ALL-NEXT:    ret <2 x i8> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1592,7 +1592,7 @@ define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) {
 
 define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef(
-; ALL-NEXT:    [[D:%.*]] = ashr <3 x i8> [[A:%.*]], <i8 6, i8 2, i8 7>
+; ALL-NEXT:    [[D:%.*]] = ashr <3 x i8> [[A:%.*]], <i8 6, i8 2, i8 undef>
 ; ALL-NEXT:    ret <3 x i8> [[D]]
 ;
   %B = sext <3 x i8> %A to <3 x i32>

diff  --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
index 7a4a9c189727..269b4619974b 100644
--- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
@@ -45,7 +45,7 @@ define <2 x i8> @trunc_lshr_trunc_nonuniform(<2 x i64> %a) {
 
 define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef(
-; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 24, i64 0>
+; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 24, i64 undef>
 ; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
@@ -131,7 +131,7 @@ define <2 x i8> @trunc_ashr_trunc_nonuniform(<2 x i64> %a) {
 
 define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef(
-; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], <i64 8, i64 0>
+; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], <i64 8, i64 undef>
 ; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;


        


More information about the llvm-commits mailing list