[llvm] 08056e1 - [InstCombine] Generalize sadd.sat combine to compute sign bits.

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 5 08:05:14 PDT 2021


Author: David Green
Date: 2021-11-05T15:05:09Z
New Revision: 08056e188869414f64598b9908a84582b817ccb5

URL: https://github.com/llvm/llvm-project/commit/08056e188869414f64598b9908a84582b817ccb5
DIFF: https://github.com/llvm/llvm-project/commit/08056e188869414f64598b9908a84582b817ccb5.diff

LOG: [InstCombine] Generalize sadd.sat combine to compute sign bits.

There is a combine in InstCombine that transforms a saturated add/sub into
a saddsat/ssubsat, currently handling only inputs that are both
sign-extended (https://alive2.llvm.org/ce/z/68qpTn). This can be
generalized, for example, to an ashr by at least the difference in
bitwidths (see https://alive2.llvm.org/ce/z/4TFyX- and
https://alive2.llvm.org/ce/z/qDWzFs). More generally, the requirement is
on "the number of sign bits": each input needs enough sign bits to be
truncated to the width of the saturate without changing its value. (For
an example using `or`, see https://alive2.llvm.org/ce/z/EI_h_A.)

So this patch uses ComputeNumSignBits (via the newly added
ComputeMinSignedBits) in matchSAddSubSat to generalize the fold to any
inputs known to have enough sign bits, truncating those inputs to the
new, narrower width of the saturate.

Differential Revision: https://reviews.llvm.org/D112298

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/sadd_sat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c19da8a01c26..6ce4dc941c5a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2304,16 +2304,6 @@ Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) {
 
   // Create the new type (which can be a vector type)
   Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
-  // Match the two extends from the add/sub
-  Value *A, *B;
-  if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
-    return nullptr;
-  // And check the incoming values are of a type smaller than or equal to the
-  // size of the saturation. Otherwise the higher bits can cause different
-  // results.
-  if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
-      B->getType()->getScalarSizeInBits() > NewBitWidth)
-    return nullptr;
 
   Intrinsic::ID IntrinsicID;
   if (AddSub->getOpcode() == Instruction::Add)
@@ -2323,10 +2313,16 @@ Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) {
   else
     return nullptr;
 
+  // The two operands of the add/sub must be nsw-truncatable to the NewTy. This
+  // is usually achieved via a sext from a smaller type.
+  if (ComputeMinSignedBits(AddSub->getOperand(0), 0, AddSub) > NewBitWidth ||
+      ComputeMinSignedBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
+    return nullptr;
+
   // Finally create and return the sat intrinsic, truncated to the new type
   Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
-  Value *AT = Builder.CreateSExt(A, NewTy);
-  Value *BT = Builder.CreateSExt(B, NewTy);
+  Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
+  Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
   Value *Sat = Builder.CreateCall(F, {AT, BT});
   return CastInst::Create(Instruction::SExt, Sat, Ty);
 }

diff  --git a/llvm/test/Transforms/InstCombine/sadd_sat.ll b/llvm/test/Transforms/InstCombine/sadd_sat.ll
index a74ed157ccc8..bf30fdee07cc 100644
--- a/llvm/test/Transforms/InstCombine/sadd_sat.ll
+++ b/llvm/test/Transforms/InstCombine/sadd_sat.ll
@@ -698,13 +698,10 @@ entry:
 define i32 @ashrA(i64 %a, i32 %b) {
 ; CHECK-LABEL: @ashrA(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 32
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = call i64 @llvm.smin.i64(i64 [[ADD]], i64 2147483647)
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = call i64 @llvm.smax.i64(i64 [[SPEC_STORE_SELECT]], i64 -2147483648)
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[B:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
 entry:
   %conv = ashr i64 %a, 32
@@ -719,15 +716,10 @@ entry:
 define i32 @ashrB(i32 %a, i64 %b) {
 ; CHECK-LABEL: @ashrB(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[B:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[A:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
 entry:
   %conv = sext i32 %a to i64
@@ -744,15 +736,12 @@ entry:
 define i32 @ashrAB(i64 %a, i64 %b) {
 ; CHECK-LABEL: @ashrAB(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 32
-; CHECK-NEXT:    [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[B:%.*]], 32
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP2]], i32 [[TMP3]])
+; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
 entry:
   %conv = ashr i64 %a, 32
@@ -795,14 +784,9 @@ define i32 @ashrA33(i64 %a, i32 %b) {
 ; CHECK-LABEL: @ashrA33(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 33
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[CONV]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP0]], i32 [[B:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP1]]
 ;
 entry:
   %conv = ashr i64 %a, 33
@@ -844,15 +828,10 @@ entry:
 define <2 x i8> @ashrv2i8_s(<2 x i16> %a, <2 x i8> %b) {
 ; CHECK-LABEL: @ashrv2i8_s(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr <2 x i16> [[A:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[CONV1:%.*]] = sext <2 x i8> [[B:%.*]] to <2 x i16>
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw <2 x i16> [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt <2 x i16> [[ADD]], <i16 -128, i16 -128>
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select <2 x i1> [[TMP0]], <2 x i16> [[ADD]], <2 x i16> <i16 -128, i16 -128>
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <2 x i16> [[SPEC_STORE_SELECT]], <i16 127, i16 127>
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select <2 x i1> [[TMP1]], <2 x i16> [[SPEC_STORE_SELECT]], <2 x i16> <i16 127, i16 127>
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc <2 x i16> [[SPEC_STORE_SELECT8]] to <2 x i8>
-; CHECK-NEXT:    ret <2 x i8> [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 8>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i16> [[TMP0]] to <2 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i8> @llvm.sadd.sat.v2i8(<2 x i8> [[TMP1]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x i8> [[TMP2]]
 ;
 entry:
   %conv = ashr <2 x i16> %a, <i16 8, i16 8>
@@ -868,13 +847,10 @@ entry:
 
 define i16 @or(i8 %X, i16 %Y) {
 ; CHECK-LABEL: @or(
-; CHECK-NEXT:    [[CONV10:%.*]] = sext i8 [[X:%.*]] to i16
-; CHECK-NEXT:    [[CONV14:%.*]] = or i16 [[Y:%.*]], -16
-; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i16 [[CONV10]], [[CONV14]]
-; CHECK-NEXT:    [[L9:%.*]] = icmp sgt i16 [[SUB]], -128
-; CHECK-NEXT:    [[L10:%.*]] = select i1 [[L9]], i16 [[SUB]], i16 -128
-; CHECK-NEXT:    [[L11:%.*]] = icmp slt i16 [[L10]], 127
-; CHECK-NEXT:    [[L12:%.*]] = select i1 [[L11]], i16 [[L10]], i16 127
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i16 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = or i8 [[TMP1]], -16
+; CHECK-NEXT:    [[TMP3:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X:%.*]], i8 [[TMP2]])
+; CHECK-NEXT:    [[L12:%.*]] = sext i8 [[TMP3]] to i16
 ; CHECK-NEXT:    ret i16 [[L12]]
 ;
   %conv10 = sext i8 %X to i16


        


More information about the llvm-commits mailing list