[llvm] 14ff38e - [InstCombine] visitTrunc - trunc (lshr (sext A), C) --> (ashr A, C) non-uniform support

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 29 07:01:27 PDT 2020


Author: Simon Pilgrim
Date: 2020-09-29T15:01:16+01:00
New Revision: 14ff38e235c4aec8e444d8aec26ce5d3a4c524d2

URL: https://github.com/llvm/llvm-project/commit/14ff38e235c4aec8e444d8aec26ce5d3a4c524d2
DIFF: https://github.com/llvm/llvm-project/commit/14ff38e235c4aec8e444d8aec26ce5d3a4c524d2.diff

LOG: [InstCombine] visitTrunc - trunc (lshr (sext A), C) --> (ashr A, C) non-uniform support

This came from @lebedev.ri's suggestion to use m_SpecificInt_ICMP for D88429. Since I was going to change m_APInt to m_Constant for that patch, I thought I would first do the same for the only other user of that APInt matcher.

I've added a ConstantExpr::getUMin helper. It's trivial to add UMAX/SMIN/SMAX as well, but I thought I'd wait until we have use cases.

Differential Revision: https://reviews.llvm.org/D88475

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constants.h
    llvm/lib/IR/Constants.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/test/Transforms/InstCombine/cast.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 8e2dba9b2417..6763d04a53e9 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -959,6 +959,7 @@ class ConstantExpr : public Constant {
   static Constant *getAnd(Constant *C1, Constant *C2);
   static Constant *getOr(Constant *C1, Constant *C2);
   static Constant *getXor(Constant *C1, Constant *C2);
+  static Constant *getUMin(Constant *C1, Constant *C2);
   static Constant *getShl(Constant *C1, Constant *C2,
                           bool HasNUW = false, bool HasNSW = false);
   static Constant *getLShr(Constant *C1, Constant *C2, bool isExact = false);

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index d84c7bc2da9d..83745b07cdd5 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2560,6 +2560,11 @@ Constant *ConstantExpr::getXor(Constant *C1, Constant *C2) {
   return get(Instruction::Xor, C1, C2);
 }
 
+Constant *ConstantExpr::getUMin(Constant *C1, Constant *C2) {
+  Constant *Cmp = ConstantExpr::getICmp(CmpInst::ICMP_ULT, C1, C2);
+  return getSelect(Cmp, C1, C2);
+}
+
 Constant *ConstantExpr::getShl(Constant *C1, Constant *C2,
                                bool HasNUW, bool HasNSW) {
   unsigned Flags = (HasNUW ? OverflowingBinaryOperator::NoUnsignedWrap : 0) |

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 5982d48e6bf6..ca55c8f5a887 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -827,23 +827,30 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
     return CastInst::CreateIntegerCast(Shift, DestTy, false);
   }
 
-  const APInt *C;
-  if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) {
+  Constant *C;
+  if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
     unsigned AWidth = A->getType()->getScalarSizeInBits();
     unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
 
     // If the shift is small enough, all zero bits created by the shift are
     // removed by the trunc.
-    if (C->getZExtValue() <= MaxShiftAmt) {
+    // TODO: Support passing through undef shift amounts - these currently get
+    // clamped to MaxAmt.
+    if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+                                    APInt(SrcWidth, MaxShiftAmt)))) {
       // trunc (lshr (sext A), C) --> ashr A, C
       if (A->getType() == DestTy) {
-        unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1);
-        return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt));
+        Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
+        Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+        ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+        return BinaryOperator::CreateAShr(A, ShAmt);
       }
       // The types are mismatched, so create a cast after shifting:
       // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
       if (Src->hasOneUse()) {
-        unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1);
+        Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false);
+        Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+        ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
         Value *Shift = Builder.CreateAShr(A, ShAmt);
         return CastInst::CreateIntegerCast(Shift, DestTy, true);
       }

diff  --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index ad6d22aa06e4..1d3d006ad238 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -1559,9 +1559,7 @@ define <2 x i8> @trunc_lshr_sext_uniform(<2 x i8> %A) {
 
 define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_uniform_undef(
-; ALL-NEXT:    [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32>
-; ALL-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 6, i32 undef>
-; ALL-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 7>
 ; ALL-NEXT:    ret <2 x i8> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1572,9 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) {
 
 define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_nonuniform(
-; ALL-NEXT:    [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32>
-; ALL-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 6, i32 2>
-; ALL-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]], <i8 6, i8 2>
 ; ALL-NEXT:    ret <2 x i8> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1585,9 +1581,7 @@ define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) {
 
 define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef(
-; ALL-NEXT:    [[B:%.*]] = sext <3 x i8> [[A:%.*]] to <3 x i32>
-; ALL-NEXT:    [[C:%.*]] = lshr <3 x i32> [[B]], <i32 6, i32 2, i32 undef>
-; ALL-NEXT:    [[D:%.*]] = trunc <3 x i32> [[C]] to <3 x i8>
+; ALL-NEXT:    [[D:%.*]] = ashr <3 x i8> [[A:%.*]], <i8 6, i8 2, i8 7>
 ; ALL-NEXT:    ret <3 x i8> [[D]]
 ;
   %B = sext <3 x i8> %A to <3 x i32>


        


More information about the llvm-commits mailing list